1. 程式人生 > >tensorflow的一些程式碼分析(五) tensorflow模型儲存和視覺化

tensorflow的一些程式碼分析(五) tensorflow模型儲存和視覺化

儲存與讀取模型

在使用tf來訓練模型的時候,難免會出現中斷的情況。這時候自然就希望能夠將辛辛苦苦得到的中間引數保留下來,不然下次又要重新開始。好在tf官方提供了儲存和讀取模型的方法。

儲存模型的方法:

# 之前是各種構建模型graph的操作(矩陣相乘,sigmoid等等....)

saver = tf.train.Saver() # 生成saver

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # 先對模型初始化

    # 然後將資料丟入模型進行訓練blablabla

    # 訓練完以後,使用saver.save 來儲存
saver.save(sess, "save_path/file_name") #file_name如果不存在的話,會自動建立
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

將模型儲存好以後,載入也比較方便,如下所示:

saver = tf.train.Saver()

with tf.Session() as sess:
    #引數可以進行初始化,也可不進行初始化。即使初始化了,初始化的值也會被restore的值給覆蓋
    sess.run(tf.global_variables_initializer())     
    saver.restore(sess, "save_path/file_name"
) #會將已經儲存的變數值resotre到 變數中。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

簡單的說,就是通過saver.save來儲存模型,通過saver.restore來載入模型

使用tensorboard來使訓練過程視覺化

tensorflow還提供了一個視覺化工具,叫tensorboard.啟動以後,可以通過網頁來觀察模型的結構和訓練過程中各個引數的變化。如下圖所示

選區_059.png-12.7kB

關於如何合理清楚的顯示網路結構,我目前還不太搞得清楚,而且目前看來也不是太重要;但是要將訓練的過程視覺化還是比較方便的。簡單的說,流程如下所示:

  • 使用tf.scalar_summary來收集想要顯示的變數
  • 定義一個summury op, 用來彙總多個變數
  • 得到一個summy writer,指定寫入路徑
  • 通過summary_str = sess.run()
# 1. 由之前的各種運算得到此批資料的loss
loss = ..... 

# 2.使用tf.scalar_summary來收集想要顯示的變數,命名為loss
tf.scalar_summary('loss',loss)  

# 3.定義一個summury op, 用來彙總由scalar_summary記錄的所有變數
merged_summary_op = tf.merge_all_summaries()

# 4.生成一個summary writer物件,需要指定寫入路徑,例如我這邊就是/tmp/logdir
summary_writer = tf.train.SummaryWriter('/tmp/logdir', sess.graph)

# 開始訓練,分批喂資料
for(i in range(batch_num)):
    # 5.使用sess.run來得到merged_summary_op的返回值
    summary_str = sess.run(merged_summary_op)

    # 6.使用summary writer將執行中的loss值寫入
    summary_writer.add_summary(summary_str,i)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

接下來,程式開始執行以後,跑到shell裡執行

$ tensorboard --logdir /tmp/logdir
  • 1
  • 1

開始執行tensorboard.接下來開啟瀏覽器,進入127.0.0.1:6006 就能夠看到loss值在訓練中的變化值了。