Source code for bigdl.contrib.onnx.converter_utils
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import numpy as np
[docs]def calc_output_shape(input, kernel, padding=0, stride=1, dilation=1, ceil_mode=False):
	def dilated_kernel_size(kernel, dilation):
		return kernel + (kernel - 1) * (dilation - 1)
	rounding = math.ceil if ceil_mode else math.floor
	out = (input + 2 * padding - dilated_kernel_size(kernel, dilation)) / stride + 1
	out = int(rounding(out))
	return out 
[docs]def parse_node_attr(node_proto):
	attrs = {}
	attr_proto = node_proto.attribute
	for attr in attr_proto:
		for field in ['f', 'i', 's']:
			if attr.HasField(field):
				attrs[attr.name] = getattr(attr, field)
				# Needed for supporting python version > 3.5
				if isinstance(attrs[attr.name], bytes):
					attrs[attr.name] = attrs[attr.name].decode(encoding='utf-8')
		for field in ['floats', 'ints', 'strings']:
			if list(getattr(attr, field)):
				assert attr.name not in attrs, "Only one type of attr is allowed"
				attrs[attr.name] = tuple(getattr(attr, field))
		for field in ['t', 'g']:
			if attr.HasField(field):
				attrs[attr.name] = getattr(attr, field)
		for field in ['tensors', 'graphs']:
			if list(getattr(attr, field)):
				raise NotImplementedError()
		if attr.name not in attrs:
			raise ValueError("Cannot parse attribute: \n{}\n.".format(attr))
	return attrs 
[docs]def parse_tensor_data(tensor_proto):
	try:
		from onnx.numpy_helper import to_array
	except ImportError:
		raise ImportError("Onnx and protobuf need to be installed.")
	if len(tuple(tensor_proto.dims)) > 0:
		np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
	else:
		# If it is a scalar tensor
		np_array = np.array([to_array(tensor_proto)])
	return np_array