1. 程式人生 > >Tensorflow: 動態的給變數tf.Variable賦值【tf.assign】

Tensorflow: 動態的給變數tf.Variable賦值【tf.assign】

Motivation

錯誤:
tensorflow不能直接給Variable賦值,比如:

embedding_var = tf.Variable(1)
test_var = 10
embedding_var = test_var

或者:

embedding_var = tf.Variable(1)
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)
x.assign(1)

解決方法

正確:
如果只需要給Variable賦值一次,可以通過assign這樣進行賦值:

import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, 1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(x)
    print sess.run(y)
    print sess.run(x)

但是通常賦一次值的意義不大,因為有時我們想將網路中的一些輸出通過saver()儲存下來,或者通過tensorboard檢視網路的embedding投影,那麼就需要將網路中產生的輸出以變數的形式儲存,這樣就可以在saver.save()的時候將這些輸出給儲存到本地,又因為tensorflow不能在圖外面直接對變數進行操作,所以我通過用一個佔位符來傳輸網路的輸出結果,然後再session裡面取出網路的輸出值,feed給該佔位符,然後將佔位符的值賦給一個臨時變數作為儲存,如下,親測有效:

flat_value = np.zeros([200,4*4*32]) 
mid_vari = tf.placeholder(tf.float32, [200,4*4*32],name="mid_vari")
embedding_var = tf.Variable(tf.zeros([200,4*4*32]), name=NAME_TO_VISUALISE_VARIABLE)
mid_vari_2 = tf.assign(embedding_var,mid_vari)

with tf.Session() as sess:
    saver =  tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    for i in range(200):
        flat_value,_=sess.run([flat,mid_vari_2],feed_dict={x:one_x,y:labels,mid_vari:flat_value})

比較周折,不過也是試了很多辦法才找到的解決方案T_T。

參考

https://blog.csdn.net/mustar_2017/article/details/79336679