1. 程式人生 > >tf.where()用法

tf.where()用法

找出tensor裡所有True值的index

import tensorflow as tf

a = tf.constant([False,False,True,False,True],dtype=tf.bool)
b = tf.where(tf.equal(a,True))
sess = tf.Session()
print(sess.run(b))

print:
[[2]
[4]]