論文解讀:DeLiGAN: Generative Adversarial Networks for Diverse and Limited Data
前言:DeLiGAN是計算機視覺頂會CVPR2017發表的一篇論文,本文將結合Python原始碼學習DeLiGAN中的核心內容。DeLiGAN最大的貢獻就是將生成對抗網路(GANs)的輸入潛空間編碼為混合模型(高斯混合模型),從而使得生成對抗網路(GANs)在數量有限但具有多樣性的訓練資料上表現出較好的效能;同時,在初始得分(Inception Score)的基礎上提出了改進版初始得分(modified Inception Score)用於評測生成樣本的類內多樣性。
論文地址:https://arxiv.org/abs/1706.02071
原始碼地址:https://github.com/val-iisc/deligan
一、概述
(一)關鍵科學問題:DeLiGAN要解決的關鍵科學問題是傳統的生成對抗網路(GANs)需要大量的訓練資料才能學習到跨圖片模態的多樣性(diversity across the image modality,根據全文內容,感覺這裡理解為資料集中包含很多不同類別影象,生成對抗網路要學習到每一個類別的類內多樣性比較困難),在訓練資料有限且類內、類間具有多樣性的情況下,傳統的生成對抗網路(GANs)的效果並不好。
(二)解決方法:將生成對抗網路中的輸入潛空間引數化為混合模型,通過對抗訓練的方式學習生成對抗網路(GANs)的引數以及該潛空間混合模型的引數,從而提高生成對抗網路(GANs)對有限數量的、且具有類內多樣性的訓練資料的學習能力。
(三)幾個概念的說明
(1)intra-class diversity:類內多樣性(inter-class:類間多樣性)
(2)modified version of inception score:改進版初始得分(m-IS)
二、核心方法解讀(結合Python原始碼)
(一)輸入噪聲取樣的潛空間
將GANs的輸入噪聲取樣潛空間重構為高斯混合模型
其中,表示高斯分佈下噪聲樣本z的概率
因為無法從訓練中得到每一個高斯分佈的權重(也就是上面公式中的),所以將所有的高斯分佈的權重置為,則上面公式變為:
為了取樣噪聲樣本,從個高斯分佈中隨機選取一個,利用“重引數化技巧(reparameterization trick)”,將噪聲樣本表示為被選中的第
Toy原始碼中的zin表示的是,zsig表示,從正太分佈或均勻分佈中取樣
這樣,噪聲樣本取樣就轉換為從高斯分佈取樣,
因為訓練生成器的原來目標是
訓練生成器的新目標就變為:
最終,通過對抗訓練學習生成對抗網路(GANs)的引數的同時,根據生成器的損失反饋到混合模型的梯度來訓練高斯模型的引數和,
引數初始化設定:從均勻分佈隨機取樣,為非零初始值0.2。
同時,因為上面公式中具有區域性最優,為了防止生成器在訓練過程中為了生成更多的高概率區域的樣本而不斷減小趨於0,在生成器的損失函式中引入正則化項,則生成器的新的損失函式的公式表示為:
(二)改進版初始得分(Inception Score)
初始得分(Inception Score):將生成影象輸入到一個訓練好的具有inception結構的分類器中就會得到一個條件標籤分佈,當足夠真實時,就會得到一個多峰(peaky)的標籤分佈,即應該具有較低的熵值。同時,我們也希望生成的圖片能涵蓋所有的類別,即應該具有較高的熵值。這兩個要求整合為初始得分(Inception Score)的測量標準:
該公式表示形式使模型具有較高初始分值但卻導致低熵條件類別分佈,但是我們需要每個類別的影象都具有多樣性,因此使用交叉熵作為訓練好的Inception模型的輸出結果,因此改進版初始得分為:
三、Python原始碼
dg_toy.py
import tensorflow as tf
import numpy as np
import os
import time
from random import randint
import cv2
import matplotlib.pylab as Plot
batchsize=50
results_dir='../results/toy'
def linear(x,output_dim, name="linear"):
""" Linear Layer for 2d input x """
w=tf.get_variable(name+"/w", [x.get_shape()[1], output_dim])
b=tf.get_variable(name+"/b", [output_dim], initializer=tf.constant_initializer(0.0))
return tf.matmul(x,w)+b
# 判別器設定1個隱層(含32個神經元),即df_dim=32
def discriminator(image, reuse=False):
""" Discriminator function description """
with tf.variable_scope('disc', reuse=reuse):
h0 = tf.tanh(linear(image,df_dim,'d_l1'))
h1 = linear(h0, 1, 'd_l2')
return tf.nn.sigmoid(h1), h1
# 生成器設定1個隱層(含32個神經元),即gf_dim=32
def generator(z, n):
""" Generator function description """
with tf.variable_scope('gen'+str(n)):
#z = tf.tanh(linear(z, batchsize,'g_l0')) # Uncomment for testing GAN++ model
h1 = tf.tanh(linear(z, gf_dim,'g_l1'))
h2 = linear(h1, 2, 'g_l2')
return tf.nn.tanh(h2)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
imageshape = [2]
z_dim = 2
gf_dim = 32 # 生成器全連線隱層的神經元數量
#gf_dim = 32*50 # Uncomment this line for testing Nx-GAN
df_dim = 32 # 判別器全連線隱層的神經元數量
learningrate = 0.0001
beta1 = 0.5
# Taking Inputs for the graph as placeholders
images = tf.placeholder(tf.float32, [batchsize] + imageshape, name="real_images")
z = tf.placeholder(tf.float32, [None, z_dim], name="z")
lr1 = tf.placeholder(tf.float32, name="lr")
zin = tf.get_variable("g_z", [batchsize, z_dim],initializer=tf.random_uniform_initializer(-1,1))
# zin對應論文公式(8)中的\mu_{i}
zsig = tf.get_variable("g_sig", [batchsize, z_dim],initializer=tf.constant_initializer(0.02))
# zsig對應論文公式(8)中的\sigma_{2},論文中說是0.2
inp = tf.add(zin,tf.mul(z,zsig)) #Uncomment this line for testing the DeliGAN
# inp也就是轉化為混合模型後,取樣得到的最終要輸入到生成器中的z
#moe = tf.eye(batchsize) #Uncomment this line for testing the MoE-GAN
#inp = tf.concat_v2([moe, z],1) #Uncomment this line for testing the MoE-GAN
# Calling the generator Function for different Models
#G = generator(z[:1],0) #Uncomment this line when testing Ensemble-GAN
G = generator(inp,0) #Uncomment this line for testing DeliGAN, MoE-GAN
#G = generator(z,0) #Uncomment this line for testing GAN and Nx-GAN
#for n in range(batchsize-1): #Uncomment this line when testing Ensemble-GAN
# g = generator(z[n+1:n+2],n+1) #Uncomment this line when testing Ensemble-GAN
#G = tf.concat_v2([g,G],0) #Uncomment this line when testing Ensemble-GAN
lab = tf.where(G[:,0]<0)
D_prob, D_logit = discriminator(images)
D_fake_prob, D_fake_logit = discriminator(G, reuse=True)
# Defining Losses
sig_loss = 0.1*tf.reduce_mean(tf.square(zsig-1))
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logit, tf.ones_like(D_logit)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logit, tf.zeros_like(D_fake_logit)))
gloss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logit, tf.ones_like(D_fake_logit)))
gloss1 = gloss+sig_loss
dloss = d_loss_real + d_loss_fake
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'd_' in var.name]
g_vars = [var for var in t_vars if 'g_' in var.name]
# 隨機生成兩類符合正太分佈的訓練資料
data = np.random.normal(0,0.3,(200,2)) # Comment this line when using multimodal (i.e. Uncomment for unimodal data)
data1 = np.random.normal(0,0.3,(200,2)) # Comment this line when using multimodal (i.e. Uncomment for unimodal data)
#data = np.random.normal(0.6,0.15,(200,2)) # Uncomment this line for multimodal data
#data1 = np.random.normal(-0.6,0.15,(200,2)) # Uncomment this line for multimodal data
data = np.vstack((data,data1))
data = data.reshape([-1,2])
# Optimization
d_optim = tf.train.AdamOptimizer(lr1, beta1=beta1).minimize(dloss, var_list=d_vars)
g_optim = tf.train.AdamOptimizer(lr1, beta1=beta1).minimize(gloss1, var_list=g_vars)
tf.initialize_all_variables().run()
saver = tf.train.Saver(max_to_keep=10)
counter = 1
start_time = time.time()
data_size = data.shape[0]
display_z = np.random.normal(0, 1.0, [batchsize, z_dim]).astype(np.float32) #Uncomment this line for using a mixture of normal prior
#display_z = np.random.uniform(-1.0, 1.0, [batchsize, z_dim]).astype(np.float32) #Uncomment this line for using a mixture of uniform distributions prior
seed = 1
rng = np.random.RandomState(seed)
train = True
thres=1.0
count=0
t1=0.73
for epoch in xrange(8000):
batch_idx = data_size/batchsize
batch = data[rng.permutation(data_size)]
if count<-1000:
t1=max(t1-0.005, 0.70)
lr = learningrate
for idx in xrange(batch_idx):
batch_images = batch[idx*batchsize:(idx+1)*batchsize]
batch_z = np.random.normal(0, 1.0, [batchsize, z_dim]).astype(np.float32)
batch_z = np.random.uniform(-1.0, 1.0, [batchsize, z_dim]).astype(np.float32)
# 這裡的batch_z相當於論文中公式(8)的\epsilon,論文是同從正太分佈N(0,1)中取樣,這裡會被後面的batch_z覆蓋掉
# Threshold to decide the which phase to run (generator or discrminator phase)
if count>10:
thres=min(thres+0.01, 1.0)
count=0
if count<-150 and thres>t1:
thres=max(thres-0.001, t1)
count=0
# Training each phase based on the value of thres and gloss
for k in xrange(5):
if gloss.eval({z: batch_z})>thres:
sess.run([g_optim],feed_dict={z: batch_z, lr1:lr})
count+=1
else:
sess.run([d_optim],feed_dict={ images: batch_images, z: batch_z, lr1:lr })
count-=1
counter += 1
# Printing training status periodically
if counter % 300 == 0:
print("Epoch: [%2d] [%4d/%4d] time: %4.4f, " % (epoch, idx, batch_idx, time.time() - start_time,))
sdata = sess.run(G,feed_dict={ z: display_z })
errD_fake = d_loss_fake.eval({z: display_z})
errD_real = d_loss_real.eval({images: batch_images})
errG = gloss.eval({z: display_z})
sl = sig_loss.eval({z: display_z})
print('D_real: ', errD_real)
print('D_fake: ', errD_fake)
print('G_err: ', errG)
print('zloss: ', sl)
# Plotting the generated samples and the training data
if counter % 1000 == 0:
f, (ax1,ax2, ax3) = Plot.subplots(1, 3)
ax1.set_autoscale_on(False)
ax2.set_autoscale_on(False)
lab1 = lab.eval({z:display_z})
gen = G.eval({z:display_z})
ax1.scatter(gen[:,0], gen[:,1]);
#ax1.scatter(gen[lab1,0], gen[lab1,1], color='r'); # Uncomment this line when testing with multimodal data
ax1.set_title('Generated samples')
ax1.set_aspect('equal')
ax1.axis([-1,1,-1,1])
ax2.scatter(batch[:,0], batch[:,1])
lab_ = batch[batch[:,0]<-0.1]
#ax2.scatter(lab_[:,0], lab_[:,1], color='r'); # Uncomment this line when testing with multimodal data
ax2.set_title('Training samples')
ax2.set_aspect('equal')
ax2.axis([-1,1,-1,1])
f.savefig(results_dir + '/plot' + str(counter) + ".png")
saver.save(sess, os.getcwd()+results_dir+'/train/',global_step=counter)