程式人生 > keras轉tensorflow lite【方法一】2步走

keras轉tensorflow lite【方法一】2步走

本文介紹的將keras轉tensorflow lite的方法,共分為兩步:

先將keras模型(.h5)轉成tensorflow的.pb檔案,再將.pb檔案轉成tensorflow lite的.tflite檔案

 

1、keras-tensorflow:我是在mac上做的

如果載入模型時出現與relu6相關的錯誤,就取消下方腳本中這一行的註解,並讓load_model那一行縮排在這個with區塊之內:

with CustomObjectScope({'relu6': mobilenet.relu6, 'DepthwiseConv2D': mobilenet.DepthwiseConv2D}):

# coding: utf-8
'''
Convert a Keras model file (.h5) into a frozen TensorFlow graph file (.pb).

Input arguments (each is given on the command line with a single leading
dash, e.g. -input_fld):
num_outputs: this value has nothing to do with the number of classes, batch_size, etc.,
and it is mostly equal to 1. If the network is a **multi-stream network**
(forked network with multiple outputs), set the value to the number of outputs. [default: 1]
quantize: if set to True, use the quantize feature of Tensorflow
(https://www.tensorflow.org/performance/quantization) [default: False]
theano_backend: Theano and Tensorflow implement convolution in different ways.
When using Keras with Theano backend, the order is set to 'channels_first'.
This feature is not fully tested, and doesn't work with quantization [default: False]
input_fld: directory holding the keras weights file [default: .]
output_fld: destination directory to save the tensorflow files
[default: '' which means "same as input_fld"]
input_model_file: name of the input weight file [default: 'model.h5']
output_model_file: name of the output weight file
[default: file name of input_model_file + '.pb']
graph_def: if set to True, will write the graph definition as an ascii file [default: False]
output_graphdef_file: if graph_def is set to True, the file name of the
graph definition [default: model.ascii]
output_node_prefix: the prefix to use for output nodes; the output index is
appended to it, e.g. output_node0 [default: output_node]
f: accepted by the parser but not used anywhere in this script
'''

# Parse input arguments
import argparse


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including the literal text ``False`` -- becomes True.
    This converter interprets the usual spellings explicitly instead.

    Args:
        value: the raw option value (a string, or already a bool).

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean
            spelling; argparse reports it as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays False, as it was under bool('').
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (e.g. True or False), got {!r}'.format(value))


parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
                    dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store",
                    dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
                    dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store",
                    dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
                    dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
                    dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
                    dest='graph_def', type=_str2bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
                    dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store",
                    dest='quantize', type=_str2bool, default=False)
parser.add_argument('-theano_backend', action="store",
                    dest='theano_backend', type=_str2bool, default=False)
# '-f' is accepted but its value is never read by this script.
parser.add_argument('-f')
args = parser.parse_args()
# The usage text and the parsed values are echoed on every run.
parser.print_help()
print('input args: ', args)

# Both options are real booleans now, so plain truthiness is sufficient.
if args.theano_backend and args.quantize:
    raise ValueError("Quantize feature does not work with theano backend.")

# initialize
from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from keras import backend as K
from keras.applications import mobilenet
from keras.utils.generic_utils import CustomObjectScope

# Resolve the output directory: an empty -output_fld means "use the input directory".
output_fld = args.output_fld or args.input_fld

# Default output name: the input file's name with '.pb' appended.
if not args.output_model_file:
    args.output_model_file = Path(args.input_model_file).name + '.pb'

# Create the output directory (and any missing parents) if it is not there yet.
Path(output_fld).mkdir(parents=True, exist_ok=True)

# Full path of the Keras file to load.
weight_file_path = str(Path(args.input_fld).joinpath(args.input_model_file))

# Load keras model
# Learning phase 0 is set before the model is loaded.
K.set_learning_phase(0)
# Theano-backend models use 'channels_first'; otherwise 'channels_last'.
if args.theano_backend:
    K.set_image_data_format('channels_first')
else:
    K.set_image_data_format('channels_last')

# NOTE: load_model() has to stay at top-level indentation. It used to be
# indented under a commented-out `with`/`try`, which made the whole script
# fail with an IndentationError.
#
# If loading fails because of relu6 / DepthwiseConv2D, wrap the call like this
# (and indent load_model under it):
#     with CustomObjectScope({'relu6': mobilenet.relu6,
#                             'DepthwiseConv2D': mobilenet.DepthwiseConv2D}):
#         net_model = load_model(weight_file_path)
#
# If it fails with a ValueError, the input file may hold only the weights and
# not the model definition. Save the model with model.save(filename.h5), which
# stores the architecture together with the weights. A file written with
# model.save_weights(filename.h5) needs the architecture to be saved separately
# (e.g. as json) and loaded before the weights. See the Keras FAQ
# (https://keras.io/getting-started/faq/) for details.
net_model = load_model(weight_file_path)
# Give every model output a predictable node name: <prefix><index>.
num_output = args.num_outputs
pred_node_names = [args.output_node_prefix + str(idx) for idx in range(num_output)]
# tf.identity attaches each name to the corresponding model output tensor.
pred = [
    tf.identity(net_model.outputs[idx], name=node_name)
    for idx, node_name in enumerate(pred_node_names)
]
print('output nodes names are: ', pred_node_names)


# [optional] dump the graph definition as a text file
# The session is fetched unconditionally because the freezing step needs it too.
sess = K.get_session()
if args.graph_def:
    ascii_path = Path(output_fld) / args.output_graphdef_file
    tf.train.write_graph(
        sess.graph.as_graph_def(),
        output_fld,
        args.output_graphdef_file,
        as_text=True,
    )
    print('saved the graph definition in ascii format at: ', str(ascii_path))


# Freeze the graph (variables -> constants) and write it as a binary .pb file
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import graph_util

graph_def = sess.graph.as_graph_def()
if args.quantize:
    # Apply the quantization transforms to the graph definition before freezing.
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    graph_def = TransformGraph(graph_def, [], pred_node_names, transforms)

# Only the sub-graph needed to compute the named output nodes is kept.
constant_graph = graph_util.convert_variables_to_constants(
    sess, graph_def, pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)

frozen_path = Path(output_fld) / args.output_model_file
print('saved the freezed graph (ready for inference) at: ', str(frozen_path))

結果:

 

2、tensorflow-tensorflow lite:linux下

用toco將.pb檔案轉成.tflite檔案(--output_arrays填上一步腳本印出的輸出節點名稱,預設為output_node0;--input_arrays與--input_shapes需依自己模型的輸入層名稱與輸入尺寸填寫):

toco --graph_def_file /home/xuqiong/data/ssrnet_3_3_3_64_1_1_morph2_age_all.pb --output_format=TFLITE --output_file=/home/xuqiong/data/ssrnet_3_3_3_64_1_1_morph2_age_all.tflite --inference_type=FLOAT --input_arrays=input_1 --output_arrays=output_node0 --input_shapes=1,64,64,3

toco --graph_def_file /home/xuqiong/data/ssrnet_3_3_3_64_1_1_morph2_age_all.pb --output_format=TFLITE --output_file=/home/xuqiong/data/ssrnet_3_3_3_64_1_1_morph2_age_all.tflite --inference_type=FLOAT --input_arrays=input_1 --output_arrays=output_node0 --input_shapes=1,64,64,3

結果:

下載到本地後:

 

參考文件:https://juejin.im/post/5b7bccc6e51d453887102e0a

                 https://blog.csdn.net/mogoweb/article/details/80152774