TensorFlow 自定義模型導出:將 .ckpt 格式轉化為 .pb 格式
本文承接上文 TensorFlow-slim 訓練 CNN 分類模型(續),闡述通過 tf.contrib.slim
的函數 slim.learning.train
訓練的模型,怎麽通過人為的加入數據入口(即占位符)來克服無法用於圖像推斷的問題。要解決這個問題,最簡單和最省時的方法是模仿。我們模仿的代碼是 TensorFlow 實現的目標檢測 API 中的文件 exporter.py,該文件的目的正是要將 TensorFlow-slim 訓練的目標檢測模型由 .ckpt 格式轉化為.pb 格式,而且其代碼中人為添加占位符的操作也正是我們需求的。坦白的說,我會用 TensorFlow 的 tf.contrib.slim
言歸正傳,現在我們回到主題,怎麽加入占位符,將前一篇文章訓練的 CNN 分類器用於圖像分類。這個問題在我們知道通過模仿 exporter.py 就可以解決它的時候,就變得異常簡單了。我們先來理順一下解決這個問題的邏輯:
1.定義數據入口,即定義占位符 inputs = tf.placeholder(···);
2.將模型作用於占位符,得到數據出口,即分類結果;
3.將訓練文件從 .ckpt 格式轉化為 .pb 格式。
按照這個邏輯順序,下面我們詳細的來看一下自定義模型導出,即模型格式轉化的代碼(命名為 exporter.py,如果沒有特別說明,exporter.py 指的都是我們修改 TensorFlow 目標檢測中的 exporter.py 後的自定義文件):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:13:27 2018
@author: shirhe-lyh
"""
"""Functions to export inference graph.
Modified from: TensorFlow models/research/object_detection/export.py
"""
import logging
import os
import tempfile
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib
slim = tf.contrib.slim
# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph_def: A GraphDef holding the graph to freeze.
        input_saver_def: A SaverDef used to restore variables, or None to
            build a Saver from the variables found in the checkpoint.
        input_checkpoint: Path (prefix) of the trained checkpoint.
        output_node_names: Comma-separated names of the output nodes to keep.
        restore_op_name: Unused; kept for interface compatibility.
        filename_tensor_name: Unused; kept for interface compatibility.
        clear_devices: Whether to strip explicit device placements so the
            frozen graph is portable.
        initializer_nodes: Comma-separated initializer node names to run
            after restoring, or '' for none.
        variable_names_blacklist: Comma-separated variable names that must
            stay variables (not folded into constants).

    Returns:
        A GraphDef with all non-blacklisted variables converted to constants.

    Raises:
        ValueError: If the checkpoint does not exist, or output_node_names
            is empty.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            "Input checkpoint '" + input_checkpoint + "' does not exist!")
    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                # No SaverDef supplied: enumerate the checkpoint's variables
                # and restore only those that also exist in the graph.
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping
                        # element) so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)
            variable_names_blacklist = (variable_names_blacklist.split(',') if
                                        variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)
    return output_graph_def
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Replaces variable values in the checkpoint with their moving averages.

    If the current checkpoint has shadow variables maintaining moving
    averages of the variables defined in the graph, this function generates
    a new checkpoint where the variables contain the values of their moving
    averages.

    Args:
        graph: A tf.Graph object.
        current_checkpoint_file: A checkpoint containing both original
            variables and their moving averages.
        new_checkpoint_file: File path to write a new checkpoint.
    """
    with graph.as_default():
        # Map each graph variable to its EMA shadow name so restoring loads
        # the averaged values into the ordinary variables.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_variable_map = ema.variables_to_restore()
        with tf.Session() as sess:
            restorer = tf.train.Saver(shadow_variable_map)
            restorer.restore(sess, current_checkpoint_file)
            # Save under the plain variable names.
            writer = tf.train.Saver()
            writer.save(sess, new_checkpoint_file)
def _image_tensor_input_placeholder(input_shape=None):
    """Returns input placeholder and a 4-D uint8 image tensor.

    Args:
        input_shape: Optional placeholder shape; defaults to
            (None, None, None, 3), i.e. a batch of 3-channel images of any
            size.

    Returns:
        A tuple (placeholder, input tensor). Both are the same tensor here
        because no decoding is needed for raw image input.
    """
    if input_shape is None:
        input_shape = (None, None, None, 3)
    input_tensor = tf.placeholder(
        dtype=tf.uint8, shape=input_shape, name='image_tensor')
    return input_tensor, input_tensor
def _encoded_image_string_tensor_input_placeholder():
    """Returns input that accepts a batch of PNG or JPEG strings.

    Returns:
        A tuple of input placeholder and the output decoded images.
    """
    batch_image_str_placeholder = tf.placeholder(
        dtype=tf.string,
        shape=[None],
        name='encoded_image_string_tensor')

    def decode(encoded_image_string_tensor):
        # Decode a single encoded image; channels=3 forces RGB output.
        image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                             channels=3)
        image_tensor.set_shape((None, None, 3))
        return image_tensor

    return (batch_image_str_placeholder,
            tf.map_fn(
                decode,
                elems=batch_image_str_placeholder,
                dtype=tf.uint8,
                parallel_iterations=32,
                back_prop=False))
# Maps the --input_type flag value to the placeholder-building function.
input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
    # 'tf_example': _tf_example_input_placeholder,
}
def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.

    Adjust according to specified implementations.

    Adds the following nodes for output tensors:
        * classes: A float32 tensor of shape [batch_size] containing class
            predictions.

    Args:
        postprocessed_tensors: A dictionary containing the following fields:
            'classes': [batch_size].
        output_collection_name: Name of collection to add output tensors to.

    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    classes = postprocessed_tensors.get('classes')  # Assume it contains 'classes'.
    # tf.identity gives the output node a stable, well-known name ('classes')
    # that inference code can look up in the frozen graph.
    outputs['classes'] = tf.identity(classes, name='classes')
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs
def write_frozen_graph(frozen_graph_path, frozen_graph_def):
    """Writes frozen graph to disk.

    Args:
        frozen_graph_path: Path to write inference graph.
        frozen_graph_def: tf.GraphDef holding frozen graph.
    """
    with gfile.GFile(frozen_graph_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
    logging.info('%d ops in the final graph.', len(frozen_graph_def.node))
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
    """Writes SavedModel to disk.

    The frozen graph already has the weights baked in, so the resulting
    SavedModel needs no separate checkpoint files during inference.

    Args:
        saved_model_path: Path to write SavedModel.
        frozen_graph_def: tf.GraphDef holding frozen graph.
        inputs: The input image tensor.
        outputs: A tensor dictionary containing the outputs of a slim model.
    """
    with tf.Graph().as_default():
        with session.Session() as sess:
            tf.import_graph_def(frozen_graph_def, name='')
            builder = tf.saved_model.builder.SavedModelBuilder(
                saved_model_path)
            tensor_info_inputs = {
                'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
            tensor_info_outputs = {}
            for k, v in outputs.items():
                tensor_info_outputs[k] = (
                    tf.saved_model.utils.build_tensor_info(v))
            # Expose the model under the standard "predict" signature so
            # TensorFlow Serving clients can call it by the default key.
            detection_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs=tensor_info_inputs,
                    outputs=tensor_info_outputs,
                    method_name=signature_constants.PREDICT_METHOD_NAME))
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        detection_signature,
                },
            )
            builder.save()
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Writes the graph and the checkpoint into disk.

    Args:
        inference_graph_def: tf.GraphDef holding the inference graph.
        model_path: Checkpoint prefix to write (e.g. '.../model.ckpt').
        input_saver_def: SaverDef used to restore the trained variables.
        trained_checkpoint_prefix: Prefix of the trained checkpoint to copy.
    """
    # Strip device placements so the saved graph is portable.
    for node in inference_graph_def.node:
        node.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    save_relative_paths=True)
            saver.restore(sess, trained_checkpoint_prefix)
            saver.save(sess, model_path)
def _get_outputs_from_inputs(input_tensors, model,
                             output_collection_name):
    """Runs the model pipeline on the inputs and registers output nodes."""
    # preprocess -> predict -> postprocess, then name the outputs.
    float_inputs = tf.to_float(input_tensors)
    preprocessed = model.preprocess(float_inputs)
    predictions = model.predict(preprocessed)
    postprocessed = model.postprocess(predictions)
    return _add_output_tensor_nodes(postprocessed,
                                    output_collection_name)
def _build_model_graph(input_type, model, input_shape,
                       output_collection_name, graph_hook_fn):
    """Build the desired graph.

    Args:
        input_type: One of the keys of input_placeholder_fn_map.
        model: Model object providing preprocess/predict/postprocess.
        input_shape: Optional fixed shape; only valid for 'image_tensor'.
        output_collection_name: Collection name for the output tensors.
        graph_hook_fn: Optional callable invoked after the graph is built.

    Returns:
        A tuple (outputs dict, input placeholder tensor).

    Raises:
        ValueError: For an unknown input type, or when input_shape is given
            with a non-'image_tensor' input type.
    """
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))
    placeholder_args = {}
    if input_shape is not None:
        if input_type != 'image_tensor':
            raise ValueError("Can only specify input shape for 'image_tensor' "
                             'inputs.')
        placeholder_args['input_shape'] = input_shape
    placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
        **placeholder_args)
    outputs = _get_outputs_from_inputs(
        input_tensors=input_tensors,
        model=model,
        output_collection_name=output_collection_name)
    # Add global step to the graph.
    slim.get_or_create_global_step()
    if graph_hook_fn:
        graph_hook_fn()
    return outputs, placeholder_tensor
def export_inference_graph(input_type,
                           model,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           use_moving_averages=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           graph_hook_fn=None):
    """Exports inference graph for the desired graph.

    Args:
        input_type: Type of input for the graph. Can be one of
            ['image_tensor', 'encoded_image_string_tensor', 'tf_example'].
            In this file, input_type must be 'image_tensor'.
        model: A model defined by model.py.
        trained_checkpoint_prefix: Path to the trained checkpoint file.
        output_directory: Path to write outputs.
        input_shape: Sets a fixed shape for an 'image_tensor' input. If not
            specified, will default to [None, None, None, 3].
        use_moving_averages: A boolean indicating whether the
            tf.train.ExponentialMovingAverage should be used or not.
        output_collection_name: Name of collection to add output tensors to.
            If None, does not add output tensors to a collection.
        additional_output_tensor_names: List of additional output tensors to
            include in the frozen graph.
        graph_hook_fn: Optional callable invoked after the model graph is
            built; may be None.
    """
    tf.gfile.MakeDirs(output_directory)
    frozen_graph_path = os.path.join(output_directory,
                                     'frozen_inference_graph.pb')
    saved_model_path = os.path.join(output_directory, 'saved_model')
    model_path = os.path.join(output_directory, 'model.ckpt')
    outputs, placeholder_tensor = _build_model_graph(
        input_type=input_type,
        model=model,
        input_shape=input_shape,
        output_collection_name=output_collection_name,
        graph_hook_fn=graph_hook_fn)
    saver_kwargs = {}
    if use_moving_averages:
        # This check is to be compatible with both versions of SaverDef.
        if os.path.isfile(trained_checkpoint_prefix):
            saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
            temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
        else:
            temp_checkpoint_prefix = tempfile.mkdtemp()
        replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            temp_checkpoint_prefix)
        checkpoint_to_use = temp_checkpoint_prefix
    else:
        checkpoint_to_use = trained_checkpoint_prefix
    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()
    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)
    if additional_output_tensor_names is not None:
        # Bug fix: dict.keys() is a view in Python 3 and does not support
        # '+'; convert to a list before concatenating.
        output_node_names = ','.join(list(outputs.keys()) +
                                     additional_output_tensor_names)
    else:
        output_node_names = ','.join(outputs.keys())
    frozen_graph_def = freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')
    write_frozen_graph(frozen_graph_path, frozen_graph_def)
    write_saved_model(saved_model_path, frozen_graph_def,
                      placeholder_tensor, outputs)
首先看定義占位符的函數 _image_tensor_input_placeholder 和 _encoded_image_string_tensor_input_placeholder ,重點關註前一個函數,因為它的輸入為一個批量圖像組成的 4 維張量(正是我們需要的),這個函數僅僅定義了一個圖像占位符 input_tensor:
input_tensor = tf.placeholder(dtype=tf.uint8, shape=input_shape, name='image_tensor')
簡單至極。接下來看 _build_model_graph 函數,這個函數將數據輸入 input_tensor (第一個參數)通過模型 model (第二個參數)作用的結果 outputs 返回。其中引用的函數 _get_outputs_from_inputs,顧名思義,由輸入數據得到分類結果。它又引用了函數 _add_output_tensor_nodes,這個函數比較重要,因為它定義了數據輸出結點
outputs['classes'] = tf.identity(classes, name='classes')
以上這些便是這個自定義文件 exporter.py 的精華,因為它實現了數據入口(name=‘image_tensor‘)和出口(name=‘classes‘)結點的定義。另一方面,這個自定義文件 exporter.py 可以作為模型導出的通用文件,而針對每一個特定的模型我們只需要修改與參數 model(表示某個特定模型) 相關的函數即可,而所有這些函數就是以上列出的函數。
為了描述的完整性,也來看一看剩下的不需要修改的函數。我們從主函數 export_inference_graph 開始,它是實際被調用的函數。它首先創建了用於保存輸出文件的文件夾,然後根據參數 model 創建了模型數據入口和出口,接下來的 if 語句是說,如果使用移動平均,則將原始 graph 中的變量用它的移動平均值來替換(函數 replace_variable_values_with_moving_averages)。再下來的 write_graph_and_checkpoint 函數相當於將上一篇文章的訓練輸出文件復制到當前指定的輸出路徑 output_directory,最後的函數 freeze_graph_with_def_protos 將 graph 中的變量變成常量,然後通過函數 write_frozen_graph 和函數 write_saved_model 寫出到輸出路徑。
最後來解釋一下函數
export_inference_graph(input_type,
                       model,
                       trained_checkpoint_prefix,
                       output_directory,
                       input_shape=None,
                       use_moving_averages=None,
                       output_collection_name='inference_op',
                       additional_output_tensor_names=None,
                       graph_hook_fn=None)
的各個參數:1.input_type,指的是輸入數據的類型,exporter.py 指定了只能從以下的字典中
input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
    # 'tf_example': _tf_example_input_placeholder,
}
選出其中一種,一般我們選擇圖像作為輸入,即 image_tensor;2.model,指的是自己構建的模型,是一個類對象,如上一篇文章定義的 Model 類的一個實例:
cls_model = model.Model(is_training=False, num_classes=10)
3.trained_checkpoint_prefix,指定要導出的 .ckpt 文件路徑;4.output_directory,指定導出文件的存儲路徑(是一個文件夾);5.input_shape,輸入數據的形狀,缺省時為 [None, None, None, 3];6.use_moving_averages,是否使用移動平均;7.output_collection_name,輸出的 collection 名,直接使用默認名,不需要修改;8.additional_output_tensor_names,指定額外的輸出張量名;9.graph_hook_fn,一個可選的回調函數,在模型圖構建完成後被調用,可用於向圖中添加額外的操作。
實際調用的時候,我們一般只需要指定前四個參數,如(命名為 export_inference_graph.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:46:16 2018
@author: shirhe-lyh
"""
"""Tool to export a model for inference.
Outputs inference graph, associated checkpoint files, a frozen inference
graph and a SavedModel (https://tensorflow.github.io/serving_basic.html).
The inference graph contains one of three input nodes depending on the user
specified option.
* 'image_tensor': Accepts a uint8 4-D tensor of shape [None, None, None, 3]
* 'encoded_image_string_tensor': Accepts a 1-D string tensor of shape
[None] containing encoded PNG or JPEG images.
* 'tf_example': Accepts a 1-D string tensor of shape [None] containing
serialized TFExample protos.
and the following output nodes returned by the model.postprocess(..):
* 'classes': Outputs float32 tensors of the form [batch_size] containing
the classes for the predictions.
Example Usage:
---------------
python/python3 export_inference_graph --input_type image_tensor --trained_checkpoint_prefix path/to/model.ckpt --output_directory path/to/exported_model_directory
The exported output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
- model.ckpt.data-00000-of-00001
- model.ckpt.index
- model.ckpt.meta
- frozen_inference_graph.pb
+ saved_model (a directory)
"""
import tensorflow as tf
import exporter
import model
slim = tf.contrib.slim
# Command-line flags for the export script.
flags = tf.app.flags
flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can '
                    "be one of ['image_tensor', 'encoded_image_string_tensor'"
                    ", 'tf_example']")
flags.DEFINE_string('input_shape', None, "If input_type is 'image_tensor', "
                    'this can explicitly set the shape of this input '
                    'to a fixed size. The dimensions are to be provided as a '
                    'comma-separated list of integers. A value of -1 can be '
                    'used for unknown dimensions. If not specified, for an '
                    "'image_tensor', the default shape will be partially "
                    "specified as '[None, None, None, 3]'.")
flags.DEFINE_string('trained_checkpoint_prefix', None,
                    'Path to trained checkpoint, typically of the form '
                    'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS
def main(_):
    """Builds the classification model and exports it per the flags."""
    cls_model = model.Model(is_training=False, num_classes=10)
    if FLAGS.input_shape:
        # Bug fix: split() yields strings, so compare against '-1' (the
        # original compared against the int -1, which never matched and so
        # never mapped unknown dimensions to None).
        input_shape = [
            int(dim) if dim != '-1' else None
            for dim in FLAGS.input_shape.split(',')
        ]
    else:
        input_shape = [None, 28, 28, 3]
    exporter.export_inference_graph(FLAGS.input_type,
                                    cls_model,
                                    FLAGS.trained_checkpoint_prefix,
                                    FLAGS.output_directory,
                                    input_shape)


if __name__ == '__main__':
    tf.app.run()
在終端運行命令:
python3 export_inference_graph.py --trained_checkpoint_prefix path/to/.ckpt-xxxx --output_directory path/to/output/directory
很快會在 output_directory 指定的文件夾中生成一系列文件,其中的 frozen_inference_graph.pb 便是我們需要的最終用於推斷的文件。至於如何讀取 .pb 文件用於推斷,則可以訪問這個系列的文章 TensorFlow 模型保存與恢復 的第二部分。為了方便閱讀,我們承接上一篇文章,使用如下代碼來對訓練的模型進行驗證:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 2 14:02:05 2018
@author: shirhe-lyh
"""
"""Evaluate the trained CNN model.
Example Usage:
---------------
python3 evaluate.py --frozen_graph_path: Path to model frozen graph.
"""
import numpy as np
import tensorflow as tf
from captcha.image import ImageCaptcha
# Command-line flags for the evaluation script.
flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS
def generate_captcha(text='1'):
    """Renders a captcha image for the given text.

    Args:
        text: The string to render (default '1').

    Returns:
        A uint8 numpy array holding the rendered 28x28 image (presumably
        RGB, i.e. shape (28, 28, 3) — determined by ImageCaptcha's output).
    """
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image
def main(_):
    """Loads the frozen graph and spot-checks it on random captchas."""
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            # Look up the graph's input/output nodes by the names given to
            # them at export time ('image_tensor' and 'classes').
            inputs = model_graph.get_tensor_by_name('image_tensor:0')
            classes = model_graph.get_tensor_by_name('classes:0')
            for _ in range(10):
                label = np.random.randint(0, 10)
                image = generate_captcha(str(label))
                # Add a batch dimension: the graph expects a 4-D tensor.
                image_np = np.expand_dims(image, axis=0)
                predicted_label = sess.run(classes,
                                           feed_dict={inputs: image_np})
                print(predicted_label, ' vs ', label)


if __name__ == '__main__':
    tf.app.run()
簡單運行:
python3 evaluate.py --frozen_graph_path path/to/frozen_inference_graph.pb
可以看到驗證結果。
本文(及前文)的所有代碼都在 github: slim_cnn_test,歡迎訪問並下載。
預告:下一篇文章將介紹 TensorFlow 如何使用預訓練文件來精調分類模型。
作者:公輸睚信
鏈接:https://www.jianshu.com/p/243d4f0b656c
來源:簡書
簡書著作權歸作者所有,任何形式的轉載都請聯系作者獲得授權並註明出處。
TensorFlow 自定義模型導出:將 .ckpt 格式轉化為 .pb 格式