1. 程式人生 > >GAN (Generative Adversarial Network)

GAN (Generative Adversarial Network)

簡單函數 理解 code amp 技術 系列 整體 最終 理論

https://www.bilibili.com/video/av9770302/?p=15

前面說了auto-encoder,VAE可以用於生成

VAE的問題,

AE的訓練是讓輸入輸出盡可能的接近,所以生成出來圖片只是在模仿訓練集,而無法生成他完全沒有見過的,或新的圖片

由於VAE並沒有真正的理解和學習如何生成新的圖片,所以對於下面的例子,他無法區分兩個case的好壞,因為從lost上看都是比7多了一個pixel

技術分享圖片

所以產生GAN,

大家都知道GAN是對抗網絡,是generator和discriminator的對抗,對抗是有一個逐漸進化的過程

過程是,

我們通過V1的generator的輸出和real images來訓練V1的discriminator,讓V1的discriminator可以判別出兩者的差別

然後,將V1的generator和V1的discriminator作為整體network訓練(這裏需要固定discriminator的參數),目標就是讓generator產生的圖片可以騙過V1的discriminator

這樣就產生出V2的generator,重復上面的過程,讓generator和discriminator分別逐漸進化

技術分享圖片

訓練Discriminator的詳細過程,

技術分享圖片

訓練generator的詳細過程,

可以看到 generator會調整參數,產生image讓discriminator判別為1,即騙過discriminator

並且在網絡訓練的時候,雖然是把generator和discriminator合一起訓練,但是要fix住discriminator的參數,不然discriminator只需要簡單的迎合generator就可以達到目標,起不到對抗的效果

技術分享圖片

下面從理論上來看下GAN,

GAN的目的是生成和目標分布(訓練集所代表的分布)所接近的分布

技術分享圖片

Pdata就是訓練數據所代表的分布

PG是我們要生成的分布

所以我們的目標就是讓PG和Pdata盡可能的close

從Pdata中sample任意m個點,然後用這些點去計算PG,用最大似然估計,算likelihood

讓這些點在PG中的概率和盡可能的大,就會讓PG分布接近Pdata

技術分享圖片

這裏的推導出,上面給出的最大似然估計,等價於求Pdata和PG的KL散度,這個是make sense的,KL散度本身就用來衡量兩個分布的相似度

這裏PG可以是任意函數,比如,你可以用高斯混合模型來生成PG,那麽theta就是高斯混合中每個高斯的參數和weight

那麽這裏給定參數和一組sample x,我們就可以用混合高斯的公式算出PG,根據上面的推導,也就得到了兩個分布的KL散度

當然高斯混合模型不夠強大,很難很好的去擬合Pdata

所以這裏是用GAN的第一個優勢,我們可以用nn去擬合PG

技術分享圖片

這個圖就是GAN的generator,z符合高斯分布,z是什麽分布不關鍵也可以是其他分布

通過Gz函數,得到x,z可以從高斯分布中sample出很多點,所以計算得到很多x,x的分布就是PG;只要nn足夠復雜,雖然z的分布式高斯,但x可以是任意分布

這裏和傳統方法,比如高斯混合的不同是,這個likelihood,即PG不好算,因為這裏G是個nn,所以我們沒有辦法直接計算得到兩個分布的KL散度

所以GAN需要discriminator,它也是一個nn,用discriminator來間接的計算PG和Pdata的相似性,從而替代KL散度的計算

技術分享圖片

GAN可以分成Generator G和Discriminator D,其中D是用來衡量PG和Pdata的相似性

技術分享圖片

最終優化目標的公式,看著很唬人,又是min,又是max

其實分成兩個步驟,

給定G,優化D,使得maxV(紅線部分),就是訓練discriminator,計算出兩個分布之間的差異值;在上圖中就是在每個小圖裏找到那個紅點

給定D,優化G,使得min(maxV),就是在訓練generator,最小化兩個分布之間的差異;就是在上圖中挑選出G3

這裏有個問題沒有講清楚的是,

為何給定G,優化D,使得maxV,得到的V可以代表兩個分布的差異?

如果這個問題明白了,下一步優化G,去最小化這個分布間的差異是很好理解的

技術分享圖片

做些簡單的轉換,如果我們要最後一步這個積分最大,那麽等價於對於每個x,積分的內容都最大

技術分享圖片

這裏是給定G,x,Pdata(x),PG(x)都是常量,所以轉換成D的一個簡單函數

求最大值,就極值,就是求導找到極點

這裏推導出當V max的時候, D的定義,並且D的值域應該在0到1之間

技術分享圖片技術分享圖片

上面推導出如果要Vmax,D要滿足

技術分享圖片

所以進一步將D帶入V的公式,這裏經過一系列推導得到,V就等價於jensen-shannon divergence

jensen-shannon divergence的定義,如下,

技術分享圖片

比KL divergence好的是,KL是非對稱的,而jensen-shannon divergence是對稱的,可以更好的反應兩個分布間的差異

那麽這裏的推導就證明,給定G,優化D讓V最大的時候,V就表示Pdata和PG的jensen-shannon divergence,所以這個Vmax就可以表示這個兩個分布的差異,也就回答了前面的問題

GAN (Generative Adversarial Network)