1. 程式人生 > >模型的儲存與載入

模型的儲存與載入

tensorflow:
有兩種方式儲存和載入模型。
①生成checkpoint file,副檔名為.ckpt,通過在tf.train.Saver物件上呼叫Saver.save()生成。包含權重和變數,但不包括圖的結構。如果需要在另一個程式中使用,需要重新建立圖形結構,並告訴Tesorflow如何處理這些權重。
模型儲存:

# 儲存變數,位於tf.train.Saver()後的變數將不會被儲存
saver = tf.train.Saver()
# ....
# 儲存模型, global_step是計數器,為訓練輪數計數
# 更新計數器
global_step.assign(i).eval()
# 載入所有引數
saver.save(sess, ckpt_dir = "/model.ckpt", global_step = global_step)

載入模型:
使用saver.restore進行模型載入。

ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

②生成圖協議檔案(graph proto file),二進位制檔案,.pb,用tf.train.write_graph()儲存。只包含圖形結構,不包含權重,然後使用tf.import_graph_def()來載入圖形。

# 當僅儲存圖模型時,才將圖寫入二進位制協議檔案中
v = tf.Variable(0, name = 'my_variable')
sess = tf.Session()
tf.train.write_grapg(sess.graph_def, '/tmp/tfmodel', 'train.pbtxt')
# 當讀取時,又從協議檔案中讀取出來
with tf.Session() as _sess:
    with gfile.FastGFile("/tmp/tfmodel/train.pbtxt", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read
()) _sess.graph.as_default() tf.import_graph_def(graph_def, name = 'tfgraph')

keras:
keras的_model和load_model方法可以將keras模型和權重儲存在一個hdf5檔案中,這裡麵包括模型的結構,權重,訓練的配置(損失函式,優化器)等。如果訓練因為某個原因終中止,就用這個hdf5檔案從上次訓練的地方重新開始訓練。
模型的載入和儲存:
①模型的結構和權重都儲存

from keras.models import save_model, load_model
# 建立一個HDF5檔案
_, frame = tempfile.mkstemp('.h5')
save_model(model, fname)
new_model = load_model(fname)

os.remove(fname)

②只儲存模型的結構,不儲存其權重及訓練的配置(損失函式,優化器)
儲存時可將模型序列化為json或者yaml檔案:

json_string = model.to_json()
yaml_string = model.to_yaml()

儲存完成後,可用如下語句載入:

from keras.models import model_from_json
model = model_from_json(json_string)
model = model_from_yaml(yaml_string)

③僅儲存模型的權重,不包含網路的結構

model.save_weights('my_model_weights.h5')
mode.load_weights('my_model_weights.h5')