1. 程式人生 > >tensorflow-訓練檢查點tf.train.Saver

tensorflow-訓練檢查點tf.train.Saver


#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]
"""

import tensorflow as tf
# Build every op of the demo inside its own graph object.
g1 = tf.Graph()

with g1.as_default():
    with tf.name_scope("input_Variable"):
        # Value that is repeatedly updated; starts at 1.0.
        my_var = tf.Variable(1, dtype=tf.float32)
    with tf.name_scope("global_step"):
        # Step counter. It is bookkeeping, not a model parameter, so it is
        # kept out of the trainable-variables collection.
        my_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    with tf.name_scope("update"):
        # One iteration: my_var <- log(my_var + 1) * 1
        varop = tf.assign(my_var, tf.multiply(tf.log(tf.add(my_var, 1)), 1))
        # ... and my_step <- my_step + 1
        stepop = tf.assign_add(my_step, 1)
        # Running addop executes both assignments.
        addop = tf.group([varop, stepop])
    with tf.name_scope("summaries"):
        # Scalar summary of my_var under the tag 'myvar'.
        tf.summary.scalar('myvar', my_var)
    with tf.name_scope("global_ops"):
        init = tf.global_variables_initializer()
        merged_summaries = tf.summary.merge_all()

# Run the update op 49 times, logging a summary after every step and writing a
# checkpoint every 5 steps (plus one final checkpoint at step 49).
with tf.Session(graph=g1) as sess:
    writer = tf.summary.FileWriter('sum_vars', sess.graph)
    try:
        sess.run(init)
        # ---0: record the initial state before any update has run.
        step, var, summary = sess.run([my_step, my_var, merged_summaries])
        writer.add_summary(summary, global_step=step)
        # %-formatting prints "step var" identically on Python 2 and 3.
        print("%s %s" % (step, var))
        saver = tf.train.Saver()
        # 1-49
        for i in range(1, 50):
            sess.run(addop)
            step, var, summary = sess.run([my_step, my_var, merged_summaries])
            writer.add_summary(summary, global_step=step)
            print("%s %s" % (step, var))
            if i % 5 == 0:
                saver.save(sess, './myvar-model/myvar-model', global_step=i)
        # Final checkpoint for the last step, which is not a multiple of 5.
        saver.save(sess, './myvar-model/myvar-model', global_step=49)

        writer.flush()
    finally:
        # Close the event file even if a run/save call above raised.
        writer.close()

38 0.0512373
39 0.04996785
40 0.048759546
41 0.04760808
42 0.04650955
43 0.045460388
44 0.04445735
45 0.04349747
46 0.042578023
47 0.041696515
48 0.040850647
49 0.04003831

儲存資料流圖的變數到二進位制檢查點檔案。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]
"""

import tensorflow as tf
import os
# Graph holding the variable, its step counter, the update ops and summaries.
g1 = tf.Graph()

with g1.as_default():
    with tf.name_scope("input_Variable"):
        # Value that is repeatedly updated; starts at 1.0.
        my_var = tf.Variable(initial_value=1, dtype=tf.float32)

    with tf.name_scope("global_step"):
        # Non-trainable step counter, starts at 0.
        my_step = tf.Variable(initial_value=0, dtype=tf.int32, trainable=False)

    with tf.name_scope("update"):
        # my_var <- log(my_var + 1) * 1, built one op at a time in the same
        # order as a nested expression would create them.
        plus_one = tf.add(my_var, 1)
        log_value = tf.log(plus_one)
        next_value = tf.multiply(log_value, 1)
        varop = tf.assign(my_var, next_value)
        # my_step <- my_step + 1
        stepop = tf.assign_add(my_step, 1)
        # Single op that triggers both assignments.
        addop = tf.group([varop, stepop])

    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar', my_var)

    with tf.name_scope("global_ops"):
        init = tf.global_variables_initializer()
        merged_summaries = tf.summary.merge_all()

# Resume from the latest checkpoint in ./myvar-model if one exists, then keep
# iterating up to step 99, writing a checkpoint every 5 steps while i <= 50.
with tf.Session(graph=g1) as sess:
    writer = tf.summary.FileWriter('sum_vars', sess.graph)
    try:
        sess.run(init)

        saver = tf.train.Saver()

        # If a checkpoint file was saved earlier, restore the model and
        # continue from the step number encoded in the checkpoint file name.
        init_step = 0
        ckpt = tf.train.get_checkpoint_state(
            os.path.join(os.getcwd(), 'myvar-model'))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Paths look like ".../myvar-model-<step>"; take the suffix.
            init_step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
            print("讀取檢查點檔案...")
        for i in range(init_step, 100):
            step, var, summary = sess.run([my_step, my_var, merged_summaries])
            writer.add_summary(summary, global_step=step)
            # %-formatting prints "step var init_step" identically on
            # Python 2 and 3.
            print("%s %s %s" % (step, var, init_step))
            if i % 5 == 0 and i <= 50:
                print("儲存檢查點檔案")
                saver.save(sess, './myvar-model/myvar-model', global_step=i)
            # Advance the state only after it has been logged (and saved).
            sess.run(addop)

        writer.flush()
    finally:
        # Close the event file even if a run/restore/save call above raised.
        writer.close()

第一次執行上面的程式碼時，每 5 步會儲存一次檢查點檔案（直到 step=50 為止）；從第二次執行開始，程式會先讀取最新的檢查點檔案並恢復變數，因此迴圈從 step=50 繼續，而不是從 0 重新開始。

第二次執行時的輸出（各欄依序為 step、變數值、init_step）：

讀取檢查點檔案...
50 0.03925755 50
儲存檢查點檔案
51 0.038506564 50
52 0.037783686 50
53 0.03708737 50
54 0.036416177 50
55 0.035768777 50
56 0.03514393 50
...
...
...
93 0.021334965 50
94 0.02111056 50
95 0.02089082 50
96 0.0206756 50
97 0.020464761 50
98 0.020258171 50
99 0.020055704 50