#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import errno
import os
import shutil
import tempfile

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile

from bigdl.nn.layer import Model
from bigdl.util.common import JTensor
from bigdl.util.common import callBigDlFunc
def get_path(output_name, sess=None):
    """
    Freeze the current graph into a single GraphDef file and return its path.

    The session's variables are saved to a checkpoint in a fresh temporary
    folder, the graph is written there as text, and both are merged into
    ``model.pb`` by ``merge_checkpoint``.

    :param output_name: name of the output node to keep in the frozen graph
    :param sess: tensorflow session holding the variables. When omitted, a new
                 session is created, its global variables are initialized, and
                 it is closed again before returning
    :return: path of the generated ``model.pb``. The file lives in a temporary
             folder that the caller is responsible for removing
    """
    owns_session = sess is None
    if owns_session:
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

    temp = tempfile.mkdtemp()
    try:
        saver = tf.train.Saver()
        saver.save(sess, temp + '/model.chkp')
        tf.train.write_graph(sess.graph, temp, 'model.pbtxt')

        merge_checkpoint(temp + '/model.pbtxt',
                         temp + '/model.chkp',
                         [output_name],
                         temp + '/model.pb', sess)
    except Exception:
        # Nothing usable was produced, so do not leave the folder behind.
        shutil.rmtree(temp, ignore_errors=True)
        raise
    finally:
        # Only release the session this function created; a caller-provided
        # session stays open for the caller to keep using.
        if owns_session:
            sess.close()
    return temp + '/model.pb'
def convert(input_ops, output_ops, byte_order, bigdl_type):
    """
    Convert tensorflow model to bigdl model

    The default graph is dumped to a temporary folder with ``dump_model`` and
    then loaded back through ``Model.load_tensorflow``. The temporary folder is
    removed before returning, whether or not the conversion succeeded.

    :param input_ops: operation list used for input, should be placeholders
    :param output_ops: operations list used for output
    :param byte_order: byte order forwarded to ``Model.load_tensorflow``
    :param bigdl_type: model variable numeric type forwarded to
                       ``Model.load_tensorflow``
    :return: bigdl model
    """
    # Build real lists: on Python 3 ``map`` yields a one-shot lazy iterator,
    # which is not a safe thing to hand to the loader.
    input_names = [op.name.split(":")[0] for op in input_ops]
    output_names = [op.name.split(":")[0] for op in output_ops]

    temp = tempfile.mkdtemp()
    try:
        # NOTE(review): bigdl_type is not forwarded to dump_model, so the
        # variables are always dumped with its default type - confirm intended.
        dump_model(path=temp)
        model_path = temp + '/model.pb'
        bin_path = temp + '/model.bin'

        model = Model.load_tensorflow(model_path, input_names, output_names,
                                      byte_order, bin_path, bigdl_type)
    finally:
        # Always clean up; a folder that is already gone is not an error.
        try:
            shutil.rmtree(temp)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    return model
def export_checkpoint(checkpoint_path):
    """
    Read every variable stored in a tensorflow checkpoint.

    The entry named ``global_step`` is skipped.

    :param checkpoint_path: tensorflow checkpoint path
    :return: dictionary mapping each variable name to the value the checkpoint
             reader returns for it
    """
    ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
    shape_map = ckpt_reader.get_variable_to_shape_map()
    # One lookup per variable name, leaving out the training step counter.
    return {
        name: ckpt_reader.get_tensor(name)
        for name in shape_map.keys()
        if name != 'global_step'
    }
def save_variable_bigdl(tensors, target_path, bigdl_type="float"):
    """
    Write a dictionary of variables to a Java object file readable by BigDL.

    Values that are not already numpy arrays are converted with ``np.array``
    before being wrapped as ``JTensor``.

    :param tensors: tensor dictionary, keyed by variable name
    :param target_path: location of the Java object file to write
    :param bigdl_type: model variable numeric type
    :return: nothing
    """
    import numpy as np

    def _as_ndarray(raw):
        # Existing arrays are passed through untouched; anything else is
        # converted.
        return raw if isinstance(raw, np.ndarray) else np.array(raw)

    converted = {
        name: JTensor.from_ndarray(_as_ndarray(tensors[name]))
        for name in tensors.keys()
    }
    callBigDlFunc(bigdl_type, "saveTensorDictionary", converted, target_path)
def dump_model(path, graph=None, sess=None, ckpt_file=None, bigdl_type="float"):
    """
    Dump a tensorflow model to files. The graph will be dumped to path/model.pb, and the
    variables will be dumped to path/model.bin

    :param path: dump folder path, must already exist
    :param graph: tensorflow graph. Defaults to the graph of the session, or to the
                  default graph when no session is available
    :param sess: if user pass in session, we assume that the variable of the graph in the
                 session has been inited. When neither a session nor a checkpoint is given,
                 a session is created, its global variables are initialized, and it is
                 closed again before returning
    :param ckpt_file: existing checkpoint to read the variables from. When omitted, the
                      session's variables are first saved to a temporary checkpoint
    :param bigdl_type: model variable numeric type
    :return: nothing
    :raises ValueError: if ``path`` is not an existing folder
    """
    if not os.path.isdir(path):
        raise ValueError("Folder " + path + " does not exist")

    temp = None
    owns_session = False
    try:
        if ckpt_file is None:
            if sess is None:
                sess = tf.Session()
                owns_session = True
                sess.run(tf.global_variables_initializer())
            # Use a prefix *inside* the temp folder so that every file the
            # saver writes is removed together with the folder.
            temp = tempfile.mkdtemp()
            ckpt_file = os.path.join(temp, "model.chkp")
            saver = tf.train.Saver()
            saver.save(sess, ckpt_file)

        # generate bin files
        tensors = export_checkpoint(ckpt_file)
        save_variable_bigdl(tensors, path + "/model.bin", bigdl_type)

        # dump graph to pb file
        if graph is None:
            # Without a session (checkpoint-only call) fall back to the
            # default graph instead of dereferencing None.
            graph = sess.graph if sess is not None else tf.get_default_graph()
        with gfile.GFile(path + "/model.pb", "wb") as f:
            f.write(graph.as_graph_def().SerializeToString())
    finally:
        # Only release the session this function created.
        if owns_session:
            sess.close()
        if temp is not None:
            # A folder that is already gone is not an error.
            try:
                shutil.rmtree(temp)
            except OSError as e:
                if e.errno != errno.ENOENT:
                    raise
def merge_checkpoint(input_graph,
                     checkpoint,
                     output_node_names,
                     output_graph,
                     sess):
    """
    Get the variable values from the checkpoint file, and merge them to the GraphDef file

    The text GraphDef is parsed, device placements are cleared, the graph is
    imported, the checkpoint is restored through the saver's restore op, and
    the graph with variables converted to constants is written in binary form.

    Args:
        input_graph: the GraphDef file (text format), doesn't contain variable values
        checkpoint: the checkpoint file
        output_node_names: A list of string, the output names
        output_graph: String of the location and the name of the
            output graph
        sess: tensorflow session used to run the restore op and to read the
            variable values
    """
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"

    input_graph_def = graph_pb2.GraphDef()
    # Read raw bytes and decode explicitly: a text-mode read already returns
    # ``str`` on Python 3, which has no ``decode`` method.
    with gfile.GFile(input_graph, "rb") as f:
        text_format.Merge(f.read().decode("utf-8"), input_graph_def)

    # Clear device placements recorded in the graph definition.
    for node in input_graph_def.node:
        node.device = ""

    importer.import_graph_def(input_graph_def, name="")

    # Feed the checkpoint path to the saver's filename tensor and restore.
    sess.run([restore_op_name], {filename_tensor_name: checkpoint})
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names,
        variable_names_blacklist=""
    )

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())