1. 程式人生 > >原始GAN論文筆記及TensorFlow實現

原始GAN論文筆記及TensorFlow實現

Welcome To My Blog

引言

  • 在GAN誕生之前,比起生成模型而言,判別模型更受關注,比如Alex Net,VGG,Google Net,因為典型的生成模型往往具有原理複雜,推導複雜,實現複雜的特點
  • 對於生成模型而言,通常有兩種建模方式
    • 最常見的是對目標物件的概率分佈建模,將其表達成具體的某種引數形式,再通過最大似然一類的方法訓練模型,如深度玻爾茲曼機DBM,這樣做的缺點:通常得到的似然函式無法直接求解,需要藉助近似演算法或者取樣演算法
    • 採用非引數的方式建模,如GSN,方法核心:假設一條馬爾科夫鏈的穩態分佈是資料的真實分佈,然後將馬爾科夫鏈的求解操作替換為可以用梯度反向傳播來執行的操作
  • GAN作為一種訓練框架,由兩個網路Generator和Discriminator構成,D採用判別式準則輔助訓練生成模型G,結構如下,X是真實資料,Z是隨機噪聲,Z經過Generator後成為X’;X’和X作為Discriminator的輸入,Discriminator根據X判斷X’是不是真實的資料,並將結果反饋給Generator.GAN目的就是希望X’儘可能地接近X,也就是P_g = P_data
    2.png

GAN的兩個網路

Generator

G本來就是做生成的,比如Auto Encoder就是一種生成模型,GAN為什麼要增加D呢?因為只用G有缺陷,以AE為例,AE側重於生成與原圖片儘可能相似的圖片,但這樣會犧牲掉圖片中各個component之間的聯絡,如下圖所示
1.png


對於AE來說,output1更像原圖,但是我們寫數字時,筆畫的中間往往不會有空缺,也就是說,雖然output2最後的筆畫拉長了,但比起output1來說更自然,因為output2更care各個component之間的聯絡.
當然了,AE可以通過增加神經網路的層數使得網路可以考慮這種聯絡,但是生成相同質量圖片的情況下,GAN的結構更加簡單.

Discriminator

D雖然是判別模型,但也可以做生成,需要解下面這個式子
3.png
也就是對於給定的輸入x,遍歷所有可能的資料,挑出分數最高的圖片作為生成結果.但是首先需要假設D(x)的形式,如果假設D(x)是線性的,那麼模型的能力太弱;如果假設是非線性的,又不好解argmax.
在GAN中,D的輸入是G的輸出,G的輸出是一張完整的圖片,D對一張完整的圖片進行判別可以很好地catch到各個component之間的聯絡,然後將這個資訊反饋給G,從而使G生成具有大局觀的圖片

GAN的數學推導

IanGoodfellow的論文Generative Adversarial Nets是這樣引出GAN的目標函式的
對於Discriminator來說,它用來判斷輸入的資料是真還是假,具體做法是:對真實的資料賦予高分,對虛假的資料賦予低分;也就是希望賦予D(X)高分,賦予D(X’)低分,可以寫成如下的形式
4.png
+ 取1-D(X’)是為了滿足對數的定義域要求
+ 取對數,個人認為是為了湊alog(x)+blog(1-x)的形式,之後會提到
+ 取期望是把分佈P_data和P_g考慮進來
對於Generator來說,希望自己生成的資料X’更接近真實資料X,也就是希望D(X’)越大越好,這便體現了G與D的博弈思想,結合G與D的初衷可得目標函式為:
5.png

目標函式的有效性

優化V(D,G)後,等價於實現了P_g = P_data,下面說明原因:

固定G,優化D

首先直接使用一個概率論中的定理:
6.png
將V(D,G)展開
7.png
最後一步合併了兩個積分,從而擴大了積分限,兩個被積函式在無定義處取0即可
剛才提到為什麼目標函式採用對數形式,原因如下
8.png
目標函式正好符合上述定理形式,所以固定G,優化D時D的最優值為:
9.png

固定D,優化G

10.png
當P_g = P_data時,上面的不等式取等號,C(G)取得最小值,說明按照上面的方式優化目標函式,效果相當於P_g = P_data,說明了GAN的可行性

優化流程

11.png
優化D的時候優化了k次,不過論文中實驗的時候取k=1
在優化G的初期,由於G生成的資料X’很假,所以log(1-D(G(z)))的梯度接近1,有點小,不利於迭代,所以會使用max_G log(D(G(z)))優化G

TensorFlow實現

完整程式碼可以參考深度學習-GAN專題程式碼復現中的”GAN的誕生”.
如果對logistic regression和交叉熵有一定的認識會對理解程式碼實現有很大幫助
1. 關於交叉熵,可以參考交叉熵與KL散度
2. 關於logistic regression,可以參考Logistic Regression邏輯斯蒂迴歸
3. TF文件中關於logistic loss的解釋
12.png

# 輸入噪聲從正態分佈中取樣得到
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)
# Generator
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob
# Discriminator
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)

    return D_prob, D_logit

個人總結

  1. GAN是一種框架,核心思想是對抗訓練:針對D,希望賦予D(X)高分,賦予D(X’)低分;針對G,希望賦予D(G(z))高分.
  2. 這種對抗訓練思想的有效性是通過求解下面的目標函式實現的,求解結果是P_G=P_data
    5.png
  3. 程式碼實現時,只要能夠體現GAN的核心思想即可,使用TensorFlow實現原始GAN模型時,由於TF有simoid_cross_entropy_with_logits這個函式,所以可以使用logistic regression對X和X’進行二分類.此時最大化樣本構成的似然函式,相當於最小化樣本標籤和D輸出之間的交叉熵.

最後推薦一下楊雙老師的課程,深度學習-GAN專題論文研讀,老師講得非常棒

參考:
楊雙:深度學習-GAN專題論文研讀
李巨集毅對抗生成網路
統計學習方法