1. 程式人生 > >【TensorFlow】tf.scatter_update()

【TensorFlow】tf.scatter_update()

在看tensorflow官網的API的時候,看到一個更新資料的函式。該函式的目的是為了能更新tensor的值,這個函式也解決了之前我想要更新tensor值的想法。在網上找了很多關於 tf.scatter_update() 的資料,但是找到的基本都是tensorflow官網上的API介紹和Stack Overflow上的提問,可見關於這個API的中文資料是相當少的,所以我打算寫下這篇部落格來介紹 tf.scatter_update()。

在這裡我簡短的介紹一下這個函式的使用:

tf.scatter_update

scatter_update(

ref, 

indices,

updates,

use_locking=None,

name=None 

)

在原始碼,函式的定義的位置在 tensorflow/python/ops/gen_state_ops.py.

引數介紹:

ref: 原來的tensor;

indices: 原來tensor中要更新的索引值,同樣也 tensor;

updates: 用於替代原來tensor的tensor值,注意,這個tensor和原來的tensor的shape要相同。

use_locking=None, name=None,一般情況下,我們使用預設的就好。

返回:依舊是一個tensor,shape和原來的tensor相同,是按照indices更新過tensor值的tensor;

介紹完了這個函式,那麼我來舉一個示例來讓大家明白怎麼去用這個函式。

程式碼如下:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
    b = tf.scatter_update(a, [0, 1], [[1, 1, 0, 0], [1, 0, 4, 0]])

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(sess.run(b))
輸出:
[[0 0 0 0]
 [0 0 0 0]]
[[1 1 0 0]
 [1 0 4 0]]

我們能看到原來的tensor是

[[0 0 0 0]

 [0 0 0 0]]

更新tensor值後的tensor是

[[1 1 0 0]

 [1 0 4 0]]

總結:1、對於tf.scatter_update()來說,ref和updates的shape一定要相同,要不然會報錯;

   2、indices也是一個tensor,我們需要更新哪一維就寫哪一維;

   3、這樣的方式適合更新整個tensor的值,特別適合批量化更新tensor;