1. 程式人生 > >tf.one_hot()使用

tf.one_hot()使用

tf.one_hot在看conditionGAN的時候注意到label的輸入要把它轉換成one-hot形式,再與噪聲z進行tf.concat輸入,之前看的時候忽略了,現在再看才算明白為什麼。

tf.one_hot(
    indices,#輸入,這裡是一維的
    depth,# one hot dimension.
    on_value=None,#output 預設1
    off_value=None,#output 預設0
    axis=None,#根據我的實驗,預設為1
    dtype=None,
    name=None
)

程式碼

import tensorflow as tf
import numpy as np
z=np.random.randint(0
,10,size=[10]) y=tf.one_hot(z,10,on_value=1,off_value=None,axis=0) with tf.Session()as sess: print(z) print(sess.run(y)) [5 7 7 0 5 5 2 0 0 0] [[0 0 0 1 0 0 0 1 1 1] [0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0] [1 0 0 0 1 1 0 0 0 0] [0 0 0 0 0 0 0 0 0 0] [0 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0]]
#!/usr/bin/env python3
# -*- coding: utf-8
-*- import tensorflow as tf import numpy as np import os os.environ["CUDA_VISIBLE_DEVICES"] = "2" z=np.random.randint(0,10,size=[10]) y=tf.one_hot(z,10,on_value=1,off_value=None) y1=tf.one_hot(z,10,on_value=1,off_value=None,axis=1) with tf.Session()as sess: print(z) print(sess.run(y)) print("axis=1按行排"
, sess.run(y1)) [6 3 4 9 6 5 5 1 2 1] [[0 0 0 0 0 0 1 0 0 0] [0 0 0 1 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 1] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 0 1 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0]] axis=1按行排 [[0 0 0 0 0 0 1 0 0 0] [0 0 0 1 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 1] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 0 1 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0]]

感覺實際用的時候可以不傳入axis值。可以看到經過one-hot的處理,輸入的維度變成了10×depth,值也變成了0和1.

下面說在condition GAN中要輸入標籤資訊y,怎樣處理的。
y是mnist的標籤值,0和10之間的整數,尺寸為[BATCH],經過one-hot處理後維度變成了[BATCH,10]值也是0和1,此時再與噪聲z按列(axis=1)連線,變成條件GAN的輸入。因此one-hot操作是必須的,這個處理在infoGAN中將z,categorical latent code、continuous latent code連線在一起輸入也要用到。

  y = tf.one_hot(y, 10, name='label_onehot')
   z = tf.random_uniform([BATCH, 100], -1, 1, name='z_train')
  tf.concat([z, y], 1)