
Exporting a Custom TensorFlow Model: Converting .ckpt Format to .pb Format


This article follows the previous post, *Training a CNN Classification Model with TensorFlow-slim (Continued)*, and explains how a model trained with tf.contrib.slim's slim.learning.train function can be made usable for image inference by manually adding a data entry point (i.e., a placeholder). The simplest and most time-saving way to solve this problem is imitation. The code we imitate is the file exporter.py from TensorFlow's object detection API: its very purpose is to convert object detection models trained with TensorFlow-slim from .ckpt format to .pb format, and the way its code manually adds placeholders is exactly what we need. Frankly, my use of TensorFlow's tf.contrib.slim module to build and train models was itself influenced by the TensorFlow models project. When I needed to train an object detector, I set up the models sub-project and read the Faster R-CNN source code in object_detection from beginning to end, which gave me a real appreciation for how simple and efficient the slim module is. (Apart from consulting the documentation, the best way to learn TensorFlow is to read the source code of the various projects in models.)

Back to the main topic: how do we add a placeholder so that the CNN classifier trained in the previous article can be used for image classification? Once we know that imitating exporter.py will solve it, the problem becomes remarkably simple. Let's first lay out the logic of the solution:

1. Define the data entry point, i.e., define a placeholder inputs = tf.placeholder(···);
2. Apply the model to the placeholder to obtain the data exit point, i.e., the classification results;
3. Convert the trained files from .ckpt format to .pb format.
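Before diving into the real code, the three steps above can be sketched as a framework-free toy (no TensorFlow; all names here are purely illustrative). "Freezing" simply means baking the current variable values in as constants, so inference no longer depends on an external checkpoint:

```python
# Toy illustration of the export logic: (1) define an input entry point,
# (2) wire the model to it, (3) "freeze" by baking variable values into
# constants so the checkpoint is no longer needed at inference time.

def build_graph(get_weight):
    """Steps 1 & 2: define the data entry point and wire the model to it."""
    def classify(pixels):                    # 'pixels' plays the placeholder role
        w = get_weight('w')                  # weight fetched at call time
        score = sum(p * w for p in pixels)   # stand-in for model.predict
        return 1 if score > 0 else 0         # stand-in for model.postprocess
    return classify

checkpoint = {'w': 0.5}                      # stand-in for a .ckpt file

# Unfrozen graph: every call still reaches out to the "checkpoint".
infer = build_graph(lambda name: checkpoint[name])

# Step 3 ("freezing"): copy the current variable values out as constants.
frozen_values = dict(checkpoint)
frozen_infer = build_graph(lambda name: frozen_values[name])
del checkpoint                               # the frozen graph keeps working

print(frozen_infer([1.0, -0.2, 0.4]))        # -> 1 (score 0.6 > 0)
```

In real TensorFlow this third step is performed by graph_util.convert_variables_to_constants, as the exporter code below shows.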

Following this logic, let's now walk through the code for custom model export, i.e., model format conversion (named exporter.py; unless otherwise noted, exporter.py refers to our custom file, modified from the exporter.py in TensorFlow's object detection project):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:13:27 2018
@author: shirhe-lyh
"""

"""Functions to export inference graph.
Modified from: TensorFlow models/research/object_detection/export.py
"""

import logging
import os
import tempfile
import tensorflow as tf

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib

slim = tf.contrib.slim


# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.
    
    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            'Input checkpoint ' + input_checkpoint + ' does not exist!')
        
    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')
        
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''
    
    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping element)
                        # so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)
            
            variable_names_blacklist = (variable_names_blacklist.split(',') if
                                        variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)
    return output_graph_def


def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Replaces variable values in the checkpoint with their moving averages.
    
    If the current checkpoint has shadow variables maintaining moving averages
    of the variables defined in the graph, this function generates a new
    checkpoint where the variables contain the values of their moving averages.
    
    Args:
        graph: A tf.Graph object.
        current_checkpoint_file: A checkpoint both original variables and
            their moving averages.
        new_checkpoint_file: File path to write a new checkpoint.
    """
    with graph.as_default():
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        ema_variables_to_restore = variable_averages.variables_to_restore()
        with tf.Session() as sess:
            read_saver = tf.train.Saver(ema_variables_to_restore)
            read_saver.restore(sess, current_checkpoint_file)
            write_saver = tf.train.Saver()
            write_saver.save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
    """Returns input placeholder and a 4-D uint8 image tensor."""
    if input_shape is None:
        input_shape = (None, None, None, 3)
    input_tensor = tf.placeholder(
        dtype=tf.uint8, shape=input_shape, name='image_tensor')
    return input_tensor, input_tensor


def _encoded_image_string_tensor_input_placeholder():
    """Returns input that accepts a batch of PNG or JPEG strings.
    
    Returns:
        A tuple of input placeholder and the output decoded images.
    """
    batch_image_str_placeholder = tf.placeholder(
        dtype=tf.string,
        shape=[None],
        name='encoded_image_string_tensor')
    def decode(encoded_image_string_tensor):
        image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                             channels=3)
        image_tensor.set_shape((None, None, 3))
        return image_tensor
    return (batch_image_str_placeholder,
            tf.map_fn(
                decode,
                elems=batch_image_str_placeholder,
                dtype=tf.uint8,
                parallel_iterations=32,
                back_prop=False))


input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
#    'tf_example': _tf_example_input_placeholder,
    }


def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.
    
    Adjust according to specified implementations.
    
    Adds the following nodes for output tensors:
        * classes: A float32 tensor of shape [batch_size] containing class
            predictions.
    
    Args:
        postprocessed_tensors: A dictionary containing the following fields:
            'classes': [batch_size].
        output_collection_name: Name of collection to add output tensors to.
        
    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    classes = postprocessed_tensors.get('classes')  # Assume containing 'classes'
    outputs['classes'] = tf.identity(classes, name='classes')
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs


def write_frozen_graph(frozen_graph_path, frozen_graph_def):
    """Writes frozen graph to disk.
    
    Args:
        frozen_graph_path: Path to write inference graph.
        frozen_graph_def: tf.GraphDef holding frozen graph.
    """
    with gfile.GFile(frozen_graph_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
    logging.info('%d ops in the final graph.', len(frozen_graph_def.node))
    
    
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
    """Writes SavedModel to disk.
    
    If checkpoint_path is not None bakes the weights into the graph thereby
    eliminating the need of checkpoint files during inference. If the model
    was trained with moving averages, setting use_moving_averages to True
    restores the moving averages, otherwise the original set of variables
    is restored.
    
    Args:
        saved_model_path: Path to write SavedModel.
        frozen_graph_def: tf.GraphDef holding frozen graph.
        inputs: The input image tensor.
        outputs: A tensor dictionary containing the outputs of a slim model.
    """
    with tf.Graph().as_default():
        with session.Session() as sess:
            tf.import_graph_def(frozen_graph_def, name='')
            
            builder = tf.saved_model.builder.SavedModelBuilder(
                saved_model_path)
            
            tensor_info_inputs = {
                'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
            tensor_info_outputs = {}
            for k, v in outputs.items():
                tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(
                    v)
                
            detection_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs=tensor_info_inputs,
                    outputs=tensor_info_outputs,
                    method_name=signature_constants.PREDICT_METHOD_NAME))
            
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        detection_signature,
                        },
            )
            builder.save()


def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Writes the graph and the checkpoint into disk."""
    for node in inference_graph_def.node:
        node.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    save_relative_paths=True)
            saver.restore(sess, trained_checkpoint_prefix)
            saver.save(sess, model_path)


def _get_outputs_from_inputs(input_tensors, model, 
                             output_collection_name):
    inputs = tf.to_float(input_tensors)
    preprocessed_inputs = model.preprocess(inputs)
    output_tensors = model.predict(preprocessed_inputs)
    postprocessed_tensors = model.postprocess(output_tensors)
    return _add_output_tensor_nodes(postprocessed_tensors,
                                    output_collection_name)
    
    
def _build_model_graph(input_type, model, input_shape, 
                           output_collection_name, graph_hook_fn):
    """Build the desired graph."""
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))
    placeholder_args = {}
    if input_shape is not None:
        if input_type != 'image_tensor':
            raise ValueError("Can only specify input shape for 'image_tensor' "
                             'inputs.')
        placeholder_args['input_shape'] = input_shape
    placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
        **placeholder_args)
    outputs = _get_outputs_from_inputs(
        input_tensors=input_tensors,
        model=model,
        output_collection_name=output_collection_name)
    
    # Add global step to the graph
    slim.get_or_create_global_step()
    
    if graph_hook_fn: graph_hook_fn()
    
    return outputs, placeholder_tensor


def export_inference_graph(input_type,
                           model,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           use_moving_averages=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           graph_hook_fn=None):
    """Exports inference graph for the desired graph.
    
    Args:
        input_type: Type of input for the graph. Can be one of ['image_tensor',
            'encoded_image_string_tensor', 'tf_example']. In this file,
            input_type must be 'image_tensor'.
        model: A model defined by model.py.
        trained_checkpoint_prefix: Path to the trained checkpoint file.
        output_directory: Path to write outputs.
        input_shape: Sets a fixed shape for an ‘image_tensor‘ input. If not
            specified, will default to [None, None, None, 3].
        use_moving_averages: A boolean indicating whether the 
            tf.train.ExponentialMovingAverage should be used or not.
        output_collection_name: Name of collection to add output tensors to.
            If None, does not add output tensors to a collection.
        additional_output_tensor_names: List of additional output tensors to
            include in the frozen graph.
    """
    tf.gfile.MakeDirs(output_directory)
    frozen_graph_path = os.path.join(output_directory,
                                     'frozen_inference_graph.pb')
    saved_model_path = os.path.join(output_directory, 'saved_model')
    model_path = os.path.join(output_directory, 'model.ckpt')
    
    outputs, placeholder_tensor = _build_model_graph(
        input_type=input_type,
        model=model,
        input_shape=input_shape,
        output_collection_name=output_collection_name,
        graph_hook_fn=graph_hook_fn)
    
    saver_kwargs = {}
    if use_moving_averages:
        # This check is to be compatible with both version of SaverDef.
        if os.path.isfile(trained_checkpoint_prefix):
            saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
            temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
        else:
            temp_checkpoint_prefix = tempfile.mkdtemp()
        replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            temp_checkpoint_prefix)
        checkpoint_to_use = temp_checkpoint_prefix
    else:
        checkpoint_to_use = trained_checkpoint_prefix
    
    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()
    
    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)
    
    if additional_output_tensor_names is not None:
        output_node_names = ','.join(list(outputs.keys()) +
                                     additional_output_tensor_names)
    else:
        output_node_names = ','.join(outputs.keys())
        
    frozen_graph_def = freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')
    write_frozen_graph(frozen_graph_path, frozen_graph_def)
    write_saved_model(saved_model_path, frozen_graph_def,
                      placeholder_tensor, outputs)

First look at the placeholder-defining functions _image_tensor_input_placeholder and _encoded_image_string_tensor_input_placeholder. Focus on the former, because its input is a 4-D tensor made up of a batch of images (exactly what we need). This function merely defines an image placeholder input_tensor:

input_tensor = tf.placeholder(dtype=tf.uint8, shape=input_shape, name='image_tensor')

Simplicity itself. Next, look at the _build_model_graph function, which returns the result outputs of applying the model model (its second parameter) to the input tensor produced for the given input type (its first parameter). It calls _get_outputs_from_inputs which, as the name suggests, derives the classification results from the input data. That function in turn calls _add_output_tensor_nodes, which is rather important because it defines the data output node:

outputs['classes'] = tf.identity(classes, name='classes')

These pieces are the essence of our custom exporter.py, because they define the data entry node (name='image_tensor') and exit node (name='classes'). Moreover, this custom exporter.py can serve as a general-purpose model export file: for each particular model we only need to modify the functions related to the parameter model (which represents that specific model), and those functions are exactly the ones listed above.
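Note that _get_outputs_from_inputs assumes the model object exposes preprocess, predict, and postprocess methods and that postprocess returns a dict with a 'classes' key. A minimal stand-in illustrating that contract (plain NumPy, no TensorFlow; the fake logits and all numbers here are purely illustrative, not the real model from the previous article):

```python
import numpy as np

# Toy stand-in for the Model class the exporter expects: it must expose
# preprocess / predict / postprocess, mirroring how _get_outputs_from_inputs
# chains them. The real class builds TensorFlow ops instead.
class ToyModel:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes

    def preprocess(self, inputs):
        # Scale uint8 pixel values into [0, 1), as a typical image model does.
        return inputs / 255.0

    def predict(self, preprocessed_inputs):
        # Fake "logits": per-example mean intensity plus a per-class bias,
        # just so the shapes are right. A real model runs a CNN here.
        batch = preprocessed_inputs.reshape(len(preprocessed_inputs), -1)
        bias = np.arange(self.num_classes)
        return {'logits': batch.mean(axis=1, keepdims=True) + bias}

    def postprocess(self, prediction_dict):
        # Reduce logits to one class id per example; this 'classes' key is
        # what _add_output_tensor_nodes reads.
        return {'classes': np.argmax(prediction_dict['logits'], axis=1)}

model = ToyModel()
images = np.zeros((2, 28, 28, 3), dtype=np.uint8)  # a fake batch of 2 images
outputs = model.postprocess(model.predict(model.preprocess(images)))
print(outputs['classes'])  # -> [9 9] (argmax of the bias for all-zero input)
```

Any model exposing this three-method interface can be dropped into the exporter unchanged.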

For completeness, let's also look at the remaining functions, which need no modification. Start from the main function export_inference_graph, the one actually called. It first creates the folder that holds the output files, then builds the model's data entry and exit points from the parameter model. The following if statement says: if moving averages are used, replace the variables in the original graph with their moving-average values (the function replace_variable_values_with_moving_averages). Next, write_graph_and_checkpoint essentially copies the training output files from the previous article into the specified output path output_directory. Finally, freeze_graph_with_def_protos turns the variables in the graph into constants, which write_frozen_graph and write_saved_model then write out to the output path.

Finally, let's go over the parameters of the function

export_inference_graph(input_type,
                       model,
                       trained_checkpoint_prefix,
                       output_directory,
                       input_shape=None,
                       use_moving_averages=None,
                       output_collection_name='inference_op',
                       additional_output_tensor_names=None,
                       graph_hook_fn=None)

one by one. 1. input_type refers to the type of the input data; exporter.py specifies that it can only be chosen from the following dictionary

input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
#    'tf_example': _tf_example_input_placeholder,
    }

picking one of its entries; generally we choose images as input, i.e., image_tensor. 2. model refers to the model you have built yourself, a class object, such as an instance of the Model class defined in the previous article:

cls_model = model.Model(is_training=False, num_classes=10)

3. trained_checkpoint_prefix specifies the path of the .ckpt file to export. 4. output_directory specifies where the exported files are stored (a directory). 5. input_shape is the shape of the input data, defaulting to [None, None, None, 3]. 6. use_moving_averages indicates whether to use moving averages. 7. output_collection_name is the name of the output collection; just use the default, no modification needed. 8. additional_output_tensor_names specifies names of extra output tensors. 9. graph_hook_fn: its exact purpose is unclear to me for now.
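The input_shape convention deserves a note: it is passed on the command line as a comma-separated list where -1 marks an unknown dimension. A tiny helper (a sketch only; parse_input_shape is a hypothetical name, mirroring the parsing convention described by the companion script's flag) makes the rule explicit:

```python
def parse_input_shape(shape_str, default=(None, None, None, 3)):
    """Parse a comma-separated shape string; -1 marks an unknown dimension."""
    if not shape_str:
        return list(default)
    return [int(dim) if int(dim) != -1 else None
            for dim in shape_str.split(',')]

print(parse_input_shape('-1,28,28,3'))  # -> [None, 28, 28, 3]
print(parse_input_shape(None))          # -> [None, None, None, 3]
```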

In actual use we generally only need to specify the first four parameters, for example (saved as export_inference_graph.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:46:16 2018
@author: shirhe-lyh
"""

"""Tool to export a model for inference.
Outputs inference graph, associated checkpoint files, a frozen inference
graph and a SavedModel (https://tensorflow.github.io/serving_basic.html).
The inference graph contains one of three input nodes depending on the user
specified option.
    * 'image_tensor': Accepts a uint8 4-D tensor of shape [None, None, None, 3]
    * 'encoded_image_string_tensor': Accepts a 1-D string tensor of shape
        [None] containing encoded PNG or JPEG images.
    * 'tf_example': Accepts a 1-D string tensor of shape [None] containing
        serialized TFExample protos.
        
and the following output nodes returned by the model.postprocess(..):
    * ‘classes‘: Outputs float32 tensors of the form [batch_size] containing
        the classes for the predictions.
        
Example Usage:
---------------
python/python3 export_inference_graph \
    --input_type image_tensor \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory
    
The exported output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
    - model.ckpt.data-00000-of-00001
    - model.ckpt.index
    - model.ckpt.meta
    - frozen_inference_graph.pb
    + saved_model (a directory)
"""
import tensorflow as tf

import exporter
import model

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can '
                    "be one of ['image_tensor', 'encoded_image_string_tensor'"
                    ", 'tf_example']")
flags.DEFINE_string('input_shape', None, "If input_type is 'image_tensor', "
                    'this can explicitly set the shape of this input '
                    'to a fixed size. The dimensions are to be provided as a '
                    'comma-separated list of integers. A value of -1 can be '
                    'used for unknown dimensions. If not specified, for an '
                    "'image_tensor', the default shape will be partially "
                    "specified as '[None, None, None, 3]'.")
flags.DEFINE_string('trained_checkpoint_prefix', None,
                    'Path to trained checkpoint, typically of the form '
                    'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS


def main(_):
    cls_model = model.Model(is_training=False, num_classes=10)
    if FLAGS.input_shape:
        input_shape = [
            int(dim) if dim != -1 else None 
            for dim in FLAGS.input_shape.split(',')
        ]
    else:
        input_shape = [None, 28, 28, 3]
    exporter.export_inference_graph(FLAGS.input_type,
                                    cls_model,
                                    FLAGS.trained_checkpoint_prefix,
                                    FLAGS.output_directory,
                                    input_shape)
    

if __name__ == '__main__':
    tf.app.run()

Run the following command in a terminal:

python3 export_inference_graph.py \
    --trained_checkpoint_prefix path/to/.ckpt-xxxx \
    --output_directory path/to/output/directory

A series of files will quickly be generated in the folder specified by output_directory, among which frozen_inference_graph.pb is the final file we need for inference. As for how to read a .pb file for inference, see the second part of the article *TensorFlow Model Saving and Restoring* in this series. For ease of reading, and continuing from the previous article, we use the following code to validate the trained model:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  2 14:02:05 2018
@author: shirhe-lyh
"""

"""Evaluate the trained CNN model.
Example Usage:
---------------
python3 evaluate.py --frozen_graph_path: Path to model frozen graph.
"""

import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    
    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            inputs = model_graph.get_tensor_by_name('image_tensor:0')
            classes = model_graph.get_tensor_by_name('classes:0')
            for i in range(10):
                label = np.random.randint(0, 10)
                image = generate_captcha(str(label))
                image_np = np.expand_dims(image, axis=0)
                predicted_label = sess.run(classes, 
                                           feed_dict={inputs: image_np})
                print(predicted_label, ' vs ', label)
            
            
if __name__ == '__main__':
    tf.app.run()

Simply run:

python3 evaluate.py --frozen_graph_path path/to/frozen_inference_graph.pb

and you will see the validation results.

All the code for this article (and the previous one) is available on GitHub: slim_cnn_test. You are welcome to visit and download it.

Coming next: the following article will describe how to fine-tune a classification model in TensorFlow using pretrained files.



Author: 公輸睚信
Link: https://www.jianshu.com/p/243d4f0b656c
Source: Jianshu (簡書)
Copyright belongs to the author. For reproduction in any form, please contact the author for authorization and cite the source.
