
Understanding the axis argument of tf.argmax

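import tensorflow as tf  # TensorFlow 1.x API: in TF 2.x these calls live under tf.compat.v1
tf.enable_eager_execution()  # run ops immediately so print() shows concrete values

value = [[0, 1, 2, 3],
         [4, 5, 6, 7]]
init = tf.constant_initializer(value)
x = tf.get_variable('x', shape=[2, 4], initializer=init)  # 2 rows x 4 columns

print(tf.argmax(x, axis=0))  # axis=0: compare down each column, one index per column
print(tf.argmax(x, axis=1))  # axis=1: compare across each row, one index per row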

Output:
tf.Tensor([1 1 1 1], shape=(4,), dtype=int64)
tf.Tensor([3 3], shape=(2,), dtype=int64)
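The output follows one rule: the axis you pass is the dimension that gets collapsed. With axis=0 the two rows are compared entry by entry, so each of the 4 columns yields the row index of its maximum (shape (4,)); with axis=1 the 4 entries of each row are compared, so each of the 2 rows yields a column index (shape (2,)). Below is a minimal pure-Python sketch of that rule for 2-D lists. The helper `argmax_2d` is hypothetical and only illustrates the semantics; it is not a TensorFlow API. It returns the first index on ties, which is a choice of this sketch, not documented TensorFlow behavior.

```python
from typing import List


def argmax_2d(matrix: List[List[float]], axis: int) -> List[int]:
    """Hypothetical helper illustrating argmax along an axis of a 2-D list.

    axis=0 collapses the rows: for each column, return the row index of its max.
    axis=1 collapses the columns: for each row, return the column index of its max.
    On ties, Python's max() keeps the first maximal index (sketch-specific choice).
    """
    if axis == 0:
        n_rows, n_cols = len(matrix), len(matrix[0])
        # One result per column, so the result length equals the number of columns.
        return [max(range(n_rows), key=lambda r: matrix[r][c]) for c in range(n_cols)]
    if axis == 1:
        # One result per row, so the result length equals the number of rows.
        return [max(range(len(row)), key=lambda c: row[c]) for row in matrix]
    raise ValueError("axis must be 0 or 1 for a 2-D input")


if __name__ == "__main__":
    value = [[0, 1, 2, 3],
             [4, 5, 6, 7]]
    print(argmax_2d(value, axis=0))  # [1, 1, 1, 1]
    print(argmax_2d(value, axis=1))  # [3, 3]
```

The two printed lists match the values in the tf.argmax output above.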