1. 程式人生 > >Tensorflow——————API 使用方法記錄

Tensorflow——————API 使用方法記錄

 返回一個onehot tensor

tf.one_hot(
    indices,   
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None
)
  • indices: tensor型別值得索引.

  • depth: 代表one-hot得深度,可以理解為對應得類別數.

  • on_value: 一個標量值,索引位置上得值. (default: 1).

  • off_value: 索引之外的值. (default: 0)

  • axis:

    新的維度,也就是填充的所在軸 (default: -1,).

  • dtype: The data type of the output tensor.

  • name: op節點名稱 (optional).

一個N維的輸入,則輸出維度為N+1,意思就是:如果輸入是一個標量,則輸出為一個固定深度的向量;如果輸入是一個 固定深度的向量,則輸出是一個二維tensor。上述兩個輸入,又分兩個情況(axis=-1或者0)。

# 輸入為標量  
[features, depth] if axis == -1
[depth, features] if axis == 0

# 輸入為固定長度的向量
[batch, features x depth] if axis == -1
[batch, depth x features] if axis == 1
[depth, batch x features] if axis == 0

舉例如下:

# 標量
indices = 3
depth = 4
a = tf.one_hot(indices, depth, axis=-1) 
sess = tf.InteractiveSession()
print(sess.run(a))

# 輸出
# axis=-1或者0,結果都是一樣的,因為就一個維度,怎麼取都是那一個。。。。。
[0. 0. 0. 1.]

--------------------------------------------------------------------------------

# 輸入為固定長度向量
indices = [0, 1, 3, 2]
depth = 5
a = tf.one_hot(indices, depth, axis=-1) 
sess = tf.InteractiveSession()
print(sess.run(a))

# axis=-1時
[[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 1. 0. 0.]]
# axis=0時
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [0. 0. 0. 0.]]

當輸出為[batch, features]時(訓練的時候都是批次的),輸出維度應該是這樣的:

[batch, features, depth] if axis == -1
[batch, depth, features] if axis == 1
[depth, batch, features] if axis == 0

舉例:

# 對這種輸入的one-hot,我是一般不會遇到,平時模型訓練,label都是一維的
indices = [[0, 2], [1, -1], [2, 3]]
depth = 5
a = tf.one_hot(indices, depth, axis=-1) 
sess = tf.InteractiveSession()
print(sess.run(a))

# 輸出
# aixs=-1
[[[1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0.]]

 [[0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]]]

# axis=0
[[[1. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [1. 0.]
  [0. 0.]]

 [[0. 1.]
  [0. 0.]
  [1. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 1.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]]

持續更新·············