#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np

from bigdl.nn.layer import SpatialAveragePooling, SpatialBatchNormalization
from bigdl.nn.layer import SpatialConvolution, SpatialMaxPooling, JoinTable
from bigdl.nn.layer import ReLU, SoftMax, CAddTable, Unsqueeze
from bigdl.nn.onnx.layer import Constant, Gather, Gemm, Shape, Reshape
from .converter_utils import *
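# Each converter below maps one ONNX operator onto a BigDL layer node.
# All converters share the same signature:
#   inputs       -- list of (tensor_value, tensor_shape) pairs, one per ONNX input
#   prev_modules -- upstream BigDL node(s) to connect the new node to
#   attrs        -- dict of the ONNX node's attributes
#   outputs      -- the ONNX node's output names
# and each returns a tuple of (new BigDL node, list of output tensor shapes).
# Helpers such as calc_output_shape and parse_tensor_data are assumed to be
# provided by converter_utils through the wildcard import above.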
def average_pool(inputs, prev_modules, attrs, outputs):
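	"""Convert an ONNX AveragePool node into a BigDL SpatialAveragePooling node."""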
	# extract attributes
	auto_pad = attrs.get('auto_pad', 'NOTSET')
	ceil_mode = attrs.get('ceil_mode', 0) == 1
	count_include_pad = attrs.get('count_include_pad', 0) == 1
	kernel_width, kernel_height = map(int, attrs.get('kernel_shape', (1, 1))[:2])
	stride_width, stride_height = map(int, attrs.get('strides', (1, 1))[:2])
	padding_width, padding_height = map(int, attrs.get('pads', (0, 0))[:2])
	# extract inputs
	_, data_tensor_shape = inputs[0]
	# calc output tensor shape
	input_height, input_width = data_tensor_shape[-2:]
	output_height = calc_output_shape(input_height, kernel_height, 
		padding = padding_height, stride = stride_height, ceil_mode = ceil_mode)
	output_width = calc_output_shape(input_width, kernel_width, 
		padding = padding_width, stride = stride_width, ceil_mode = ceil_mode)
	out_tensor_shape = list(data_tensor_shape)
	out_tensor_shape[-2] = output_height
	out_tensor_shape[-1] = output_width
	out_tensor_shape = tuple(out_tensor_shape)
	# create module node
	module = SpatialAveragePooling(kw=kernel_width, kh=kernel_height,
				dw=stride_width, dh=stride_height,
				pad_w=padding_width, pad_h=padding_height,
				ceil_mode=ceil_mode, count_include_pad=count_include_pad
		)(prev_modules)
	return module, [out_tensor_shape] 
def batch_norm(inputs, prev_modules, attrs, outputs):
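	"""Convert an ONNX BatchNormalization node into a BigDL SpatialBatchNormalization node.

	The scale, bias, running mean and running variance inputs are used to initialize the layer.
	"""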
	# extract attributes
	epsilon = float(attrs.get('epsilon', 1e-05))
	momentum = float(attrs.get('momentum', 0.9))
	# extract inputs
	_, data_tensor_shape = inputs[0]
	scale_tensor_val, _ = inputs[1]
	bias_tensor_val, _ = inputs[2]
	mean_tensor_val, _ = inputs[3]
	var_tensor_val, _ = inputs[4]
	# calc output tensor shape
	out_tensor_shape = data_tensor_shape
	# create module node
	n_output = int(data_tensor_shape[1])
	temp_module = SpatialBatchNormalization(n_output=n_output, eps=epsilon,
		momentum=momentum, init_weight=scale_tensor_val, init_bias=bias_tensor_val)
	if mean_tensor_val is not None:
		temp_module.set_running_mean(mean_tensor_val)
	if var_tensor_val is not None:
		temp_module.set_running_std(var_tensor_val)
	module = temp_module(prev_modules[0])
	return module, [out_tensor_shape] 
def concat(inputs, prev_modules, attrs, outputs):
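	"""Convert an ONNX Concat node into a BigDL JoinTable node.

	The 0-based ONNX axis is mapped to BigDL's 1-based join dimension.
	"""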
	# extract attributes
	axis = int(attrs.get('axis'))
	# extract inputs
	_, data_tensor_shape = inputs[0]
	# calc output tensor shape
	dim_rank = 0
	for i in range(len(inputs)):
		_, curr_input_shape = inputs[i]
		for j in range(len(data_tensor_shape)):
			if axis != j:
				if curr_input_shape[j] != data_tensor_shape[j]:
					raise ValueError("Input shape mismatch. Expected input shape " +
						str(data_tensor_shape[j]) + " but got " + str(curr_input_shape[j]))
			else:
				dim_rank += curr_input_shape[axis]
	out_tensor_shape = list(data_tensor_shape)
	out_tensor_shape[axis] = dim_rank
	out_tensor_shape = tuple(out_tensor_shape)
	# create module node
	module = JoinTable(dimension=axis+1, n_input_dims=len(data_tensor_shape))(prev_modules)
	return module, [out_tensor_shape] 
def constant(inputs, prev_modules, attrs, outputs):
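	"""Convert an ONNX Constant node into a BigDL Constant node holding the parsed tensor value."""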
	# extract attributes
	value = parse_tensor_data(attrs.get('value'))
	# calc output tensor shape
	out_tensor_shape = value.shape
	# create module node
	module = Constant(value)(prev_modules[0])
	return module, [out_tensor_shape] 
def conv(inputs, prev_modules, attrs, outputs):
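	"""Convert an ONNX Conv node into a BigDL SpatialConvolution node with the given weight and optional bias tensors."""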
	# extract attributes
	auto_pad = attrs.get('auto_pad', 'NOTSET')
	padW, padH = map(int, attrs.get('pads', (0, 0))[:2])
	kernelW, kernelH = map(int, attrs.get('kernel_shape', (0, 0))[:2])
	strideW, strideH = map(int, attrs.get('strides', (1, 1))[:2])
	dilationW, dilationH = map(int, attrs.get('dilations', (1, 1))[:2])
	group = int(attrs.get('group', 1))
	withBias = len(inputs) == 3 and inputs[2] is not None
	# extract inputs
	data_tensor_val, data_tensor_shape = inputs[0]
	weight_tensor_val, weight_tensor_shape = inputs[1]
	bias_tensor_val = None
	if withBias:
		bias_tensor_val, _ = inputs[2]
	# calc output tensor shape
	input_batch_size, n_input_plane = map(int, data_tensor_shape[:2])
	n_output_plane = weight_tensor_shape[0]
	input_height, input_width = data_tensor_shape[-2:]
	output_height = calc_output_shape(input_height, kernelH, padding = padH, stride=strideH)
	output_width = calc_output_shape(input_width, kernelW, padding = padW, stride=strideW)
	out_tensor_shape = (input_batch_size, n_output_plane, output_height, output_width)
	# create module node
	module = SpatialConvolution(
		n_input_plane=n_input_plane, n_output_plane=n_output_plane,
		kernel_w=kernelW, kernel_h=kernelH, stride_w=strideW, stride_h=strideH,
		pad_w=padW, pad_h=padH, n_group=group, init_weight=weight_tensor_val,
		init_bias=bias_tensor_val, with_bias=withBias
		)(prev_modules[0])
	return module, [out_tensor_shape] 
def gather(inputs, prev_modules, attrs, outputs):
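	"""Convert an ONNX Gather node into a BigDL Gather node.

	Only gathering along axis 0 is supported.
	"""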
	# extract attributes
	axis = int(attrs.get('axis', 0))
	if axis != 0:
		raise ValueError("Gather layer axis value")
	# extract inputs
	data_tensor_val, data_tensor_shape = inputs[0]
	indices_val, indices = inputs[1]
	# calc output tensor shape
	out_tensor_shape = tuple(list(data_tensor_shape[:axis]) + list(indices) + list(data_tensor_shape[axis + 1:]))
	# create module node
	module = Gather()(prev_modules)
	return module, [out_tensor_shape] 
def gemm(inputs, prev_modules, attrs, outputs):
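	"""Convert an ONNX Gemm node into a BigDL Gemm node with constant B and C matrices."""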
	# extract attributes
	alpha = float(attrs.get("alpha", 1.0))
	beta = float(attrs.get("beta", 1.0))
	trans_a = int(attrs.get("transA", 0))
	trans_b = int(attrs.get("transB", 0))
	# extract inputs
	_, tensor_a_shape = inputs[0]
	tensor_b_val, tensor_b_shape = inputs[1]
	tensor_c_val, tensor_c_shape = inputs[2]
	# create module node
	module = Gemm(alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b,
				matrix_b=tensor_b_val, matrix_c=tensor_c_val)(prev_modules)
	return module, [tensor_c_shape] 
def max_pool(inputs, prev_modules, attrs, outputs):
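	"""Convert an ONNX MaxPool node into a BigDL SpatialMaxPooling node.

	The dilations and storage_order attributes are read but not applied.
	"""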
	# extract attributes
	auto_pad = attrs.get("auto_pad", 'NOTSET')
	kernelW, kernelH = map(int, attrs.get("kernel_shape")[:2])
	strideW, strideH = map(int, attrs.get("strides", (1, 1))[:2])
	dilationW, dilationH = map(int, attrs.get('dilations', (1, 1))[:2])
	padW, padH = map(int, attrs.get("pads", (0, 0))[:2])
	ceil_mode = attrs.get("ceil_mode", 0) != 0
	storage_order = int(attrs.get("storage_order", 0))
	# extract inputs
	_, data_tensor_shape = inputs[0]
	input_height, input_width = data_tensor_shape[-2:]
	# calc output tensor shape
	output_width = calc_output_shape(input_width, kernelW,
		padding=padW, stride=strideW, dilation=dilationW, ceil_mode=ceil_mode)
	output_height = calc_output_shape(input_height, kernelH,
		padding=padH, stride=strideH, dilation=dilationH, ceil_mode=ceil_mode)
	out_tensor_shape_list = list(data_tensor_shape)
	out_tensor_shape_list[2] = output_height
	out_tensor_shape_list[3] = output_width
	out_tensor_shape = tuple(out_tensor_shape_list)
	# create module node
	module = SpatialMaxPooling(kw=kernelW, kh=kernelH, dw=strideW, dh=strideH,
		pad_w=padW, pad_h=padH, to_ceil=ceil_mode)(prev_modules[0])
	return module, [out_tensor_shape] 
def relu(inputs, prev_modules, attrs, outputs):
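	"""Convert an ONNX Relu node into a BigDL ReLU node."""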
	# extract inputs
	_, data_tensor_shape = inputs[0]
	# calc output tensor shape
	output_shape = data_tensor_shape
	# create module node
	module = ReLU()(prev_modules[0])
	return module, [output_shape] 
def reshape(inputs, prev_modules, attrs, outputs):
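	"""Convert an ONNX Reshape node into a BigDL Reshape node using the constant target shape when it is available."""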
	# extract inputs
	_, data_tensor_shape = inputs[0]
	shape_tensor_val, _ = inputs[1]
	shape_array = None
	if shape_tensor_val is not None:
		shape_array = np.squeeze(shape_tensor_val).astype(int).tolist()
	# create module node
	module = Reshape(shape_array)(prev_modules)
	return module, [shape_tensor_val] 
def shape(inputs, prev_modules, attrs, outputs):
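	"""Convert an ONNX Shape node into a BigDL Shape node whose output is a 1-D tensor of the input's dimensions."""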
	# extract inputs
	_, data_tensor_shape = inputs[0]
	# create module node
	module = Shape()(prev_modules[0])
	return module, [(len(data_tensor_shape),)] 
def softmax(inputs, prev_modules, attrs, outputs):
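	"""Convert an ONNX Softmax node into a BigDL SoftMax node.

	The axis attribute is read but not passed on to the BigDL layer.
	"""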
	_, data_tensor_shape = inputs[0]
	out_tensor_shape = data_tensor_shape
	axis = int(attrs.get('axis', 1))
	module = SoftMax()(prev_modules[0])
	return module, [out_tensor_shape] 
def _sum(inputs, prev_modules, attrs, outputs):
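	"""Convert an ONNX Sum node into a BigDL CAddTable node that adds all inputs element-wise."""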
	_, data_tensor_shape = inputs[0]
	out_tensor_shape = data_tensor_shape
	module = CAddTable()(prev_modules)
	return module, [out_tensor_shape]
def unsqueeze(inputs, prev_modules, attrs, outputs):
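	"""Convert an ONNX Unsqueeze node into a BigDL Unsqueeze node.

	Only the first entry of the axes attribute is passed to the BigDL layer.
	"""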
	axes = list(map(int, attrs.get('axes')))
	data_tensor_val, data_tensor_shape = inputs[0]
	out_tensor_shape = list(data_tensor_shape)
	for idx in axes:
		out_tensor_shape.insert(idx, 1)
	out_tensor_shape = tuple(out_tensor_shape)
	module = Unsqueeze(axes[0], len(data_tensor_shape))(prev_modules)
	return module, [out_tensor_shape]