Source code for bigdl.contrib.onnx.onnx_loader
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import onnx
from bigdl.nn.onnx.layer import *
from bigdl.nn.layer import Identity, Model
from .ops_mapping import _convert_map as convert_map
from .converter_utils import parse_node_attr, parse_tensor_data


class OnnxLoader(object):
    """Load a serialized ONNX model and convert it into a BigDL Model."""

    def load_model(self, file_path):
        """Read an ONNX model file and convert its graph into a BigDL Model."""
        model_proto = onnx.load_model(file_path)
        # self._ir_version = model_proto.ir_version
        # self._opset_import = model_proto.opset_import
        # self._producer_name = model_proto.producer_name
        # self._producer_version = model_proto.producer_version
        # self._domain = model_proto.domain
        # self._model_version = model_proto.model_version
        # self._doc_string = model_proto.doc_string
        graph_proto = model_proto.graph
        return self.load_graph(graph_proto)

    def load_graph(self, graph_proto):
        """Convert an ONNX GraphProto into a BigDL Model."""
        if not graph_proto:
            raise ValueError("Graph proto is required")
        input_nodes = list()
        output_nodes = list()
        tensor_map = dict()
        initialized_tensors = set()
        module_map = dict()
        root_nodes = list()
        # A dummy Identity node acts as the single root that all graph inputs
        # are attached to.
        dummy_root = Identity()()
        # Register initializer tensors (pre-trained weights) and record their shapes.
        for tensor in graph_proto.initializer:
            if not tensor.name.strip():
                raise ValueError("Tensor's name is required")
            initialized_tensors.add(tensor.name)
            tensor_data = parse_tensor_data(tensor)
            tensor_map[tensor.name] = (tensor_data, tensor_data.shape)
        # Graph inputs that are not initializers are the model's actual inputs.
        for gin in graph_proto.input:
            if gin.name not in initialized_tensors:
                input_nodes.append(gin.name)
                shape = tuple([dim.dim_value for dim in gin.type.tensor_type.shape.dim])
                module_map[gin.name] = Identity()(dummy_root)
                tensor_map[gin.name] = (None, shape)
        # Collect the names of the graph outputs.
        for gout in graph_proto.output:
            if gout.name not in initialized_tensors:
                output_nodes.append(gout.name)
        # Convert each ONNX node into the corresponding BigDL module.
        for node in graph_proto.node:
            name = node.name.strip()
            op_type = node.op_type
            inputs = [tensor_map[n] for n in node.input]
            outputs = node.output
            prev_modules = [module_map[n] for n in node.input if n not in initialized_tensors]
            attrs = parse_node_attr(node)
            if len(prev_modules) == 0:
                root_nodes.append((name, op_type))
                prev_modules = [dummy_root]
            bigdl_module, outputs_shape = self._make_module_from_onnx_node(
                op_type, inputs, prev_modules, attrs, outputs)
            assert len(outputs) == len(outputs_shape)
            for out, out_shape in zip(outputs, outputs_shape):
                module_map[out] = bigdl_module
                tensor_map[out] = (None, out_shape)
        in_modules = [module_map[m] for m in input_nodes]
        out_modules = [module_map[m] for m in output_nodes]
        # All graph inputs were attached to dummy_root above, so the dummy root
        # serves as the single BigDL input node.
        model = Model([dummy_root], out_modules)
        return model

    def _make_module_from_onnx_node(self, op_type, inputs, prev_modules, attrs, outputs):
        module = None
        out_shapes = []
        if op_type in convert_map:
            module, out_shapes = convert_map[op_type](inputs, prev_modules, attrs, outputs)
        else:
            raise NotImplementedError("Unsupported ONNX op type: %s" % op_type)
        return module, out_shapes 
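
    # Note (an illustrative sketch, not part of the original source): each entry
    # in `_convert_map` is assumed to be a callable with the signature implied by
    # the call above:
    #     converter(inputs, prev_modules, attrs, outputs) -> (bigdl_module, out_shapes)
    # where `inputs` is a list of (tensor_data, shape) pairs, `prev_modules` holds
    # the upstream BigDL nodes, `attrs` is the parsed attribute dict, and `outputs`
    # lists the ONNX output names. A hypothetical converter for a shape-preserving
    # op could look roughly like this:
    #
    #     def _identity_converter(inputs, prev_modules, attrs, outputs):
    #         _, shape = inputs[0]                    # shape of the single input tensor
    #         module = Identity()(prev_modules[0])    # wrap the upstream BigDL node
    #         return module, [shape]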


def load(model_path):
    """Load an ONNX model from a file path and return a BigDL Model."""
    loader = OnnxLoader()
    return loader.load_model(model_path)


def load_model_proto(model_proto):
    """Convert an in-memory ONNX ModelProto into a BigDL Model."""
    loader = OnnxLoader()
    return loader.load_graph(model_proto.graph)
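

# Example usage (an illustrative sketch, not part of the original module): load a
# pretrained ONNX model from disk and run a forward pass on the resulting BigDL
# Model. The file path and input shape below are placeholders.
#
#     import numpy as np
#     from bigdl.contrib.onnx.onnx_loader import load
#
#     model = load("/path/to/model.onnx")            # returns a bigdl.nn.layer.Model
#     sample = np.random.random((1, 3, 224, 224))    # input shape depends on the graph
#     prediction = model.forward(sample)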