Source code for bigdl.contrib.onnx.onnx_loader

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import onnx
from bigdl.nn.onnx.layer import *
from bigdl.nn.layer import Identity, Model
from .ops_mapping import _convert_map as convert_map
from .converter_utils import parse_node_attr, parse_tensor_data


class OnnxLoader(object):
    """Convert an ONNX model file or graph proto into a BigDL ``Model``."""

    def load_model(self, file_path):
        """Read an ONNX model from ``file_path`` and convert its graph.

        :param file_path: location handed unchanged to ``onnx.load_model``.
        :return: whatever :meth:`load_graph` returns for the model's graph.
        """
        model_proto = onnx.load_model(file_path)
        return self.load_graph(model_proto.graph)

    def load_graph(self, graph_proto):
        """Translate an ONNX graph proto into a BigDL ``Model``.

        Initializer tensors are parsed once and remembered together with
        their shape. Every graph input that is not an initializer gets its
        own ``Identity`` node attached to a single shared root node. Graph
        nodes are then converted one by one, in the order they are listed,
        through :meth:`_make_module_from_onnx_node`.

        :param graph_proto: the ONNX graph to convert.
        :return: ``Model([root], outputs)`` where ``root`` is the shared root
            node and ``outputs`` are the modules registered for the graph
            outputs that are not initializers.
        :raises ValueError: if ``graph_proto`` is falsy or an initializer has
            a blank name.
        :raises KeyError: if a node refers to a tensor that has not been
            registered by an initializer, a graph input or an earlier node.
        :raises NotImplementedError: if a node's op type has no converter.
        """
        if not graph_proto:
            raise ValueError("Graph proto is required")

        # tensor name -> (data or None, shape)
        tensor_map = {}
        # tensor name -> BigDL node producing that tensor
        module_map = {}
        initialized_tensors = set()
        # Single entry point of the resulting model; nodes without any
        # non-initializer input are attached directly to it.
        dummy_root = Identity()()

        for tensor in graph_proto.initializer:
            if not tensor.name.strip():
                raise ValueError("Tensor's name is required")
            initialized_tensors.add(tensor.name)
            tensor_data = parse_tensor_data(tensor)
            tensor_map[tensor.name] = (tensor_data, tensor_data.shape)

        for gin in graph_proto.input:
            if gin.name in initialized_tensors:
                continue
            shape = tuple(dim.dim_value for dim in gin.type.tensor_type.shape.dim)
            module_map[gin.name] = Identity()(dummy_root)
            tensor_map[gin.name] = (None, shape)

        output_nodes = [gout.name for gout in graph_proto.output
                        if gout.name not in initialized_tensors]

        for node in graph_proto.node:
            inputs = [tensor_map[n] for n in node.input]
            outputs = node.output
            prev_modules = [module_map[n] for n in node.input
                            if n not in initialized_tensors]
            attrs = parse_node_attr(node)
            if not prev_modules:
                prev_modules = [dummy_root]

            bigdl_module, outputs_shape = self._make_module_from_onnx_node(
                node.op_type, inputs, prev_modules, attrs, outputs)
            assert len(outputs) == len(outputs_shape), \
                "converter for %s returned %d shapes for %d outputs" % (
                    node.op_type, len(outputs_shape), len(outputs))

            # Every output name of the node maps to the same BigDL module.
            for out, out_shape in zip(outputs, outputs_shape):
                module_map[out] = bigdl_module
                tensor_map[out] = (None, out_shape)

        out_modules = [module_map[m] for m in output_nodes]
        return Model([dummy_root], out_modules)

    def _make_module_from_onnx_node(self, op_type, inputs, prev_modules, attrs, outputs):
        """Dispatch one ONNX node to the converter registered for its op type.

        :param op_type: ONNX operator name used as the key into ``convert_map``.
        :param inputs: ``(data, shape)`` entries for the node's inputs.
        :param prev_modules: BigDL nodes feeding this node.
        :param attrs: parsed node attributes.
        :param outputs: output names of the node.
        :return: ``(module, out_shapes)`` as produced by the converter.
        :raises NotImplementedError: if ``op_type`` is not in ``convert_map``.
        """
        if op_type not in convert_map:
            # ``NotImplemented`` is a non-callable constant; the exception
            # class is ``NotImplementedError``.
            raise NotImplementedError("Unsupported ONNX op type: %s" % op_type)
        module, out_shapes = convert_map[op_type](inputs, prev_modules, attrs, outputs)
        return module, out_shapes
def load(model_path):
    """Load the ONNX model at ``model_path`` through a fresh ``OnnxLoader``.

    :param model_path: forwarded unchanged to ``OnnxLoader.load_model``.
    :return: the result of ``OnnxLoader.load_model(model_path)``.
    """
    return OnnxLoader().load_model(model_path)
def load_model_proto(model_proto):
    """Convert the graph of an already-loaded model proto.

    :param model_proto: object whose ``graph`` attribute is handed to
        ``OnnxLoader.load_graph``.
    :return: the result of ``OnnxLoader.load_graph(model_proto.graph)``.
    """
    return OnnxLoader().load_graph(model_proto.graph)