# Source code for bigdl.util.tf_utils

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import errno
import os
import shutil
import tempfile

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile

from bigdl.nn.layer import Model
from bigdl.util.common import JTensor
from bigdl.util.common import callBigDlFunc

def get_path(output_name, sess=None):
    """
    Freeze the current TensorFlow graph into a single protobuf file and
    return the file's path.

    A checkpoint and a text graph are first written to a temporary folder,
    then merged by ``merge_checkpoint`` into ``model.pb`` with the variables
    converted to constants.

    :param output_name: name of the output node to keep in the frozen graph
    :param sess: optional TensorFlow session; when None a new session is
                 created and all global variables are initialized in it
    :return: path of the generated model.pb file
    """
    if sess is None:
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

    folder = tempfile.mkdtemp()
    checkpoint_prefix = folder + '/model.chkp'
    frozen_path = folder + '/model.pb'

    # Persist variables and graph structure, then merge them together.
    tf.train.Saver().save(sess, checkpoint_prefix)
    tf.train.write_graph(sess.graph, folder, 'model.pbtxt')
    merge_checkpoint(folder + '/model.pbtxt',
                     checkpoint_prefix,
                     [output_name],
                     frozen_path, sess)
    return frozen_path
def convert(input_ops, output_ops, byte_order, bigdl_type):
    """
    Convert tensorflow model to bigdl model

    :param input_ops: operation list used for input, should be placeholders
    :param output_ops: operations list used for output
    :param byte_order: byte order the model weights are dumped in
                       (passed through to Model.load_tensorflow)
    :param bigdl_type: model variable numeric type, e.g. "float"
    :return: bigdl model
    """
    # List comprehensions instead of map(): under Python 3, map() returns a
    # one-shot iterator, which would be exhausted/opaque by the time
    # Model.load_tensorflow consumes it.
    input_names = [op.name.split(":")[0] for op in input_ops]
    output_names = [op.name.split(":")[0] for op in output_ops]
    temp = tempfile.mkdtemp()

    dump_model(path=temp)
    model_path = temp + '/model.pb'
    bin_path = temp + '/model.bin'

    model = Model.load_tensorflow(model_path, input_names, output_names,
                                  byte_order, bin_path, bigdl_type)

    try:
        shutil.rmtree(temp)
    except OSError as e:
        # BUGFIX: `errno` was referenced here without being imported, so any
        # rmtree failure raised NameError instead of being handled. The module
        # now imports errno; a concurrently-removed directory is tolerated,
        # anything else is re-raised.
        if e.errno != errno.ENOENT:
            raise

    return model
def export_checkpoint(checkpoint_path):
    """
    Export variable tensors from the checkpoint files.

    :param checkpoint_path: tensorflow checkpoint path
    :return: dictionary of tensor. The key is the variable name and the value is the numpy
    """
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    # Read every variable except the training-step counter.
    return {name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map().keys()
            if name != 'global_step'}
def save_variable_bigdl(tensors, target_path, bigdl_type="float"):
    """
    Save a variable dictionary to a Java object file, so it can be read by BigDL

    :param tensors: tensor dictionary
    :param target_path: where is the Java object file store
    :param bigdl_type: model variable numeric type
    :return: nothing
    """
    import numpy as np
    jtensors = {}
    for name, value in tensors.items():
        # JTensor.from_ndarray requires an ndarray; coerce anything else.
        arr = value if isinstance(value, np.ndarray) else np.array(value)
        jtensors[name] = JTensor.from_ndarray(arr)
    callBigDlFunc(bigdl_type, "saveTensorDictionary", jtensors, target_path)
def dump_model(path, graph=None, sess=None, ckpt_file=None, bigdl_type="float"):
    """
    Dump a tensorflow model to files. The graph will be dumped to path/model.pb,
    and the checkpoint will be dumped to path/model.bin

    :param path: dump folder path
    :param graph: tensorflow graph. Default use the default graph of the session
    :param sess: if user pass in session, we assume that the variable of the graph
                 in the session has been inited
    :param ckpt_file: path of an existing checkpoint to read the variables from;
                      when None, a fresh checkpoint is written from `sess`
    :param bigdl_type: model variable numeric type
    :return: nothing
    """
    if not os.path.isdir(path):
        raise ValueError("Folder " + path + " does not exist")

    temp = None
    if ckpt_file is None:
        if sess is None:
            sess = tf.Session()
            init = tf.global_variables_initializer()
            sess.run(init)
        temp = tempfile.mkdtemp()
        ckpt_file = temp
        # dump checkpoint to temp files
        saver = tf.train.Saver()
        saver.save(sess, ckpt_file)

    # generate bin files
    tensors = export_checkpoint(ckpt_file)
    save_variable_bigdl(tensors, path + "/model.bin", bigdl_type)

    # dump graph to pb file
    if graph is None:
        # BUGFIX: the original read sess.graph unconditionally, which raised
        # AttributeError when ckpt_file was supplied but sess was None. Fall
        # back to the default graph in that case.
        graph = sess.graph if sess is not None else tf.get_default_graph()
    with gfile.GFile(path + "/model.pb", "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())

    if temp is not None:
        try:
            shutil.rmtree(temp)
        except OSError as e:
            # BUGFIX: errno was used without being imported (module now
            # imports it). Ignore "already gone", re-raise everything else.
            if e.errno != errno.ENOENT:
                raise
def merge_checkpoint(input_graph, checkpoint, output_node_names, output_graph, sess):
    """
    Get the variable values from the checkpoint file, and merge them to the GraphDef file
    Args:
        input_graph: the GraphDef file, doesn't contain variable values
        checkpoint: the checkpoint file
        output_node_names: A list of string, the output names
        output_graph: String of the location and the name of the output graph
        sess: an active TensorFlow session the graph is imported and restored into
    """
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"

    input_graph_def = graph_pb2.GraphDef()
    # BUGFIX: open in binary mode. Under Python 3 text mode ("r") yields str,
    # which has no decode() method; "rb" yields bytes on both Python 2 and 3.
    with gfile.FastGFile(input_graph, "rb") as f:
        text_format.Merge(f.read().decode("utf-8"), input_graph_def)

    # Strip device placements so the graph is portable across machines.
    for node in input_graph_def.node:
        node.device = ""

    importer.import_graph_def(input_graph_def, name="")
    # Restore the variable values from the checkpoint into the session.
    sess.run([restore_op_name], {filename_tensor_name: checkpoint})
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names,
        variable_names_blacklist=""
    )

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())