# Source code for bigdl.contrib.onnx.converter_utils
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import numpy as np
def calc_output_shape(input, kernel, padding=0, stride=1, dilation=1, ceil_mode=False):
    """Return the output length of a sliding window (conv/pool) along one axis.

    The result is ``round((input + 2 * padding - effective_kernel) / stride) + 1``,
    where ``effective_kernel = kernel + (kernel - 1) * (dilation - 1)`` and
    ``round`` is floor by default or ceil when ``ceil_mode`` is true.

    Args:
        input: Length of the input along this axis.
        kernel: Window (kernel) size along this axis.
        padding: Padding added to each side of the axis (counted twice).
        stride: Step between consecutive windows; must be non-zero.
        dilation: Spacing between kernel taps; 1 means no dilation.
        ceil_mode: Round the number of steps up instead of down.

    Returns:
        The output length as an ``int``.
    """
    # Span actually covered by a dilated kernel.
    effective_kernel = kernel + (kernel - 1) * (dilation - 1)
    span = input + 2 * padding - effective_kernel
    # Floor division is exact for ints on both Python 2 and 3. The previous
    # `span / stride` truncated on Python 2 (making ceil_mode a no-op) and
    # went through floats on Python 3 (losing precision for large sizes).
    if ceil_mode:
        steps = -(-span // stride)  # ceil(span / stride)
    else:
        steps = span // stride      # floor(span / stride)
    return int(steps) + 1
def parse_node_attr(node_proto):
    """Collect the attributes of an ONNX node into a plain ``dict``.

    Each entry in ``node_proto.attribute`` is mapped by its ``name`` to:
      * the scalar value of ``f`` / ``i`` / ``s`` (``s`` decoded to ``str``),
      * a ``tuple`` for non-empty ``floats`` / ``ints`` / ``strings``
        (``strings`` elements decoded to ``str``),
      * the raw ``t`` / ``g`` sub-message as returned by ``getattr``.

    Args:
        node_proto: An ONNX ``NodeProto``-like object exposing ``attribute``.

    Returns:
        dict mapping attribute name to its parsed value.

    Raises:
        NotImplementedError: if an attribute carries ``tensors`` or ``graphs``.
        ValueError: if an attribute holds both a scalar and a list value, or
            if no supported field is set on it.
    """
    def _to_text(value):
        # protobuf returns ONNX `bytes` fields as `bytes` on Python 3; decode
        # so scalar and repeated string attributes both come back as `str`.
        if isinstance(value, bytes):
            return value.decode(encoding='utf-8')
        return value

    attrs = {}
    for attr in node_proto.attribute:
        for field in ['f', 'i', 's']:
            if attr.HasField(field):
                attrs[attr.name] = _to_text(getattr(attr, field))
        for field in ['floats', 'ints', 'strings']:
            values = tuple(getattr(attr, field))
            if values:
                # Explicit check instead of `assert`, which `python -O` strips.
                if attr.name in attrs:
                    raise ValueError("Only one type of attr is allowed")
                attrs[attr.name] = tuple(_to_text(v) for v in values)
        for field in ['t', 'g']:
            if attr.HasField(field):
                attrs[attr.name] = getattr(attr, field)
        for field in ['tensors', 'graphs']:
            if list(getattr(attr, field)):
                raise NotImplementedError()
        if attr.name not in attrs:
            raise ValueError("Cannot parse attribute: \n{}\n.".format(attr))
    return attrs
def parse_tensor_data(tensor_proto):
    """Convert an ONNX tensor message into a numpy array.

    Tensors with a non-empty ``dims`` list are returned with that shape.
    A tensor with no ``dims`` (a scalar) is wrapped in a 1-element array.

    Args:
        tensor_proto: An ONNX ``TensorProto``-like message accepted by
            ``onnx.numpy_helper.to_array``.

    Returns:
        numpy.ndarray holding the tensor data.

    Raises:
        ImportError: if ``onnx`` (and its protobuf dependency) is unavailable.
    """
    try:
        from onnx.numpy_helper import to_array
    except ImportError:
        raise ImportError("Onnx and protobuf need to be installed.")

    shape = tuple(tensor_proto.dims)
    if not shape:
        # Scalar tensor: promote to a one-element array.
        return np.array([to_array(tensor_proto)])
    return to_array(tensor_proto).reshape(shape)