How to convert a trained Keras model into a TensorFlow .pb file and call it from TensorFlow Serving

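First, a model trained in Keras and saved with the built-in model.save() is stored as an HDF5 file. It is usually named .h5; in this post it is named .model.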

The model is loaded back with my_model = keras.models.load_model(filepath).
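Because the file extension is arbitrary, it can help to confirm that a file really is HDF5 before passing it to load_model. The following is a minimal sketch using only the standard library. It assumes the file has no HDF5 user block, so the 8-byte HDF5 signature sits at offset 0; the helper name looks_like_hdf5 is made up for this illustration.

```python
# The 8-byte signature that starts an HDF5 file (when no user block is present).
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def looks_like_hdf5(path):
    """Return True if the file at `path` begins with the HDF5 signature."""
    with open(path, "rb") as f:
        return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE
```

For example, looks_like_hdf5('biLSTM_brand_recognize.model') should be true for a file written by model.save(), whatever its extension.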

To convert this model into a TensorFlow model in .pb format, use the following code:

    # -*- coding: utf-8 -*-
    import os
    import os.path as osp

    import tensorflow as tf
    from keras import backend as K
    from keras.models import load_model
    from tensorflow.python.framework import graph_io
    from tensorflow.python.framework.graph_util import convert_variables_to_constants


    def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
        """Return a GraphDef in which the session's variables are replaced by constants."""
        graph = session.graph
        with graph.as_default():
            # Freeze every global variable except those listed in keep_var_names.
            freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
            output_names = list(output_names or [])
            output_names += [v.op.name for v in tf.global_variables()]
            input_graph_def = graph.as_graph_def()
            if clear_devices:
                # Drop device placements so the graph can be loaded on any machine.
                for node in input_graph_def.node:
                    node.device = ""
            frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                          output_names, freeze_var_names)
            return frozen_graph


    input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
    weight_file = 'biLSTM_brand_recognize.model'
    output_graph_name = 'tensor_model_v3.pb'

    output_fld = osp.join(input_fld, 'tensorflow_model')
    if not os.path.isdir(output_fld):
        os.mkdir(output_fld)
    weight_file_path = osp.join(input_fld, weight_file)

    # Switch Keras to inference mode before loading, so layers such as Dropout
    # are built with their test-time behavior.
    K.set_learning_phase(0)
    net_model = load_model(weight_file_path)

    print('input is: %s' % net_model.input.name)
    print('output is: %s' % net_model.output.name)

    sess = K.get_session()
    frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

    # Write a binary protobuf, which is what GraphDef.ParseFromString expects
    # when the file is read back.
    graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

    print('saved the constant graph (ready for inference) at: %s' % osp.join(output_fld, output_graph_name))

The model is now saved as a .pb file.

Here is the catch: a .pb file saved this way is a frozen model (a bare GraphDef with the weights baked in as constants).

Trying to load it through TensorFlow Serving fails with the following error:

 E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`

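This happens because TensorFlow Serving expects to read a SavedModel, which carries a meta graph tagged "serve" along with its signatures.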
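The error message itself points to saved_model_cli for inspecting the tag-sets in an export. As a rough sketch, the helper below builds and runs the `saved_model_cli show --dir <dir> --all` command. It assumes saved_model_cli is on the PATH (it is installed together with the TensorFlow pip package); the function names are made up for this illustration.

```python
import subprocess


def saved_model_cli_show_command(export_dir):
    """Build the argv for listing every tag-set and signature in export_dir."""
    return ["saved_model_cli", "show", "--dir", export_dir, "--all"]


def show_saved_model(export_dir):
    """Run saved_model_cli and return its report as text (needs TensorFlow installed)."""
    output = subprocess.check_output(saved_model_cli_show_command(export_dir))
    return output.decode("utf-8")
```

A directory that TensorFlow Serving can load should report a tag-set containing 'serve' and at least one signature.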

So the frozen model has to be converted into the SavedModel format. The solution is as follows:

    # Continues the script above: tf and net_model are already defined.
    from tensorflow.python.saved_model import signature_constants
    from tensorflow.python.saved_model import tag_constants

    # SavedModelBuilder refuses to write into a directory that already exists.
    export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
    # Must point at the frozen graph written in the previous step.
    graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model_v3.pb'

    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

    # Read the frozen graph back as a binary GraphDef.
    with tf.gfile.GFile(graph_pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    sigs = {}

    with tf.Session(graph=tf.Graph()) as sess:
        # name="" is important to ensure we don't get spurious prefixing
        tf.import_graph_def(graph_def, name="")
        g = tf.get_default_graph()
        # Look up the input and output tensors by the names they had in the Keras model.
        inp = g.get_tensor_by_name(net_model.input.name)
        out = g.get_tensor_by_name(net_model.output.name)

        # Register a predict signature under the default serving key.
        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
            tf.saved_model.signature_def_utils.predict_signature_def(
                {"in": inp}, {"out": out})

        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)

    builder.save()


The exported SavedModel directory then contains two entries:

saved_model.pb   variables

The variables directory may be empty, because the weights were already frozen into constants inside the graph.
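One detail worth keeping in mind: TensorFlow Serving looks for numbered version subdirectories under the model base path (for example <base>/1/saved_model.pb), which matches the "version: 1" in the error message earlier. The sketch below copies an export into such a layout using only the standard library. The function name publish_saved_model and its parameters are hypothetical, and the paths passed in are up to the reader.

```python
import os
import shutil


def publish_saved_model(export_dir, model_base_path, version=1):
    """Copy a SavedModel export into <model_base_path>/<version>/ and return that path."""
    if not os.path.isfile(os.path.join(export_dir, "saved_model.pb")):
        raise ValueError("no saved_model.pb found in %s" % export_dir)
    target = os.path.join(model_base_path, str(int(version)))
    if os.path.exists(target):
        raise ValueError("version directory already exists: %s" % target)
    # copytree creates the missing parent directories and copies variables/ too.
    shutil.copytree(export_dir, target)
    return target
```

The server would then be started with its model base path set to model_base_path, not to the numbered subdirectory.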

Loading this SavedModel into TensorFlow Serving and reading it again then succeeds!
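To call the served model, one option is TensorFlow Serving's REST predict endpoint. The sketch below is written under several assumptions that are not in the original post: the server was started with the REST API enabled on port 8501, the model is registered under the hypothetical name "my_model", and each instance is a list of integer token ids. The input key "in" and the signature name "serving_default" come from the export code above. It targets Python 3.

```python
import json
from urllib import request


def build_predict_request(instances, model_name="my_model", host="localhost", port=8501):
    """Build a POST request for the /v1/models/<name>:predict endpoint."""
    url = "http://%s:%d/v1/models/%s:predict" % (host, port, model_name)
    payload = {
        "signature_name": "serving_default",
        # Row format: one object per instance, keyed by the signature's input name.
        "instances": [{"in": row} for row in instances],
    }
    body = json.dumps(payload).encode("utf-8")
    return request.Request(url, data=body, headers={"Content-Type": "application/json"})


def predict(instances, **kwargs):
    """Send the request and return the "predictions" list (needs a running server)."""
    with request.urlopen(build_predict_request(instances, **kwargs)) as resp:
        return json.loads(resp.read().decode("utf-8"))["predictions"]
```

For example, predict([[12, 7, 301]]) would send one sequence; the values must match whatever preprocessing and sequence length the model was trained with.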