1. 程式人生 > >tf.nn.embedding_lookup用法解釋

tf.nn.embedding_lookup用法解釋

Welcome to my blog
tf.nn.embedding_lookup( params, ids, …),主要使用params, ids兩個引數,函式的功能是從params中挑出索引為ids的元素,並返回一個張量,
假設params的shape是batch * hidden, ids的shape是batch * n
那麼函式返回張量的shape是batch *n * hidden

import tensorflow as tf
w = tf.constant([[1,2],[3,4],[5,6]])
res = tf.nn.embedding_lookup(w, [
0,1,0]) res2 = tf.nn.embedding_lookup(w, [[0],[1],[0]]) res4 = tf.nn.embedding_lookup(w, [[0,0,0,0],[1,1,1,1],[0,0,0,0]]) with tf.Session() as sess: res,res2,res4 = sess.run([res,res2,res4]) print res,res.shape print res2,res2.shape print res4,res4.shape ''' 列印結果 res: (3, 2) [[1 2] [3 4] [1 2]] res2:(3, 1, 2) [[[1 2]] [[3 4]] [[1 2]]] res4: (3, 4, 2) [[[1 2] [1 2] [1 2] [1 2]] [[3 4] [3 4] [3 4] [3 4]] [[1 2] [1 2] [1 2] [1 2]]] (3, 4, 2) '''