1. 程式人生 > >GAN背後的理論依據,以及為什麼只使用GAN網路容易產生

GAN背後的理論依據,以及為什麼只使用GAN網路容易產生

花了一下午研究的文章,解答了我關於GAN網路的很多疑問,內容的理論水平很高,只能儘量理解,但真的是一篇非常好的文章轉自http://www.dataguru.cn/article-10570-1.html

GAN回顧

 

Martin 稱這個loss為original cost function(參見[1] 2.2.1章節),而實際操作中採用的loss為the –log D cost(參見[1] 2.2.2章節)。

 

GAN存在的問題:初探

 

當固定G時,訓練D直到收斂,可以發現D的loss會越來越小,趨於0,這表明JSD(Pr || Pg)被較大化了,並且趨於log2。如下圖所示。

而這會導致什麼問題呢?在實踐中人們發現,當D訓練得更較精確,G的更新會變得越差,訓練變得異常地不穩定。為什麼會產生這些這樣的問題?之前一直沒有人給出解答。

 

JSD(Pr || Pg)達到較大化,有兩種可能:

概率分佈不是()連續的,也就是說,它沒有密度函式。我們常見的分佈一般都有密度函式。如果概率分佈是定義在一個低維的流形上(維度低於全空間),那它就不是連續的。

 

分佈是連續的,但是兩者的支撐集沒有交集。兩個分佈的支撐集不外乎包含以下四種情形:

經過計算可以發現(參見下期推送),case1 的JSD小於log2,case2的JSD等於log2,case3和4的JSD也不超過log2。實際上這很好理解,兩個分佈差異越大,交叉熵越大。

 

這是比較直觀的解釋,更進一步地,作者從理論上進行了嚴格的分析和證明。

 

GAN存在的問題:理論分析

 

Lemma 1:設

是一個由仿射變換和逐點定義的非線性函式(ReLU、leaky ReLU或者諸如sigmoid、tanh、softplus之類的光滑嚴格遞增函式)複合得到的複合函式,則g(Z)包含在可數多個流形的並集中,並且它的維數至多為dim(Z)。因此,若dim(Z) < dim(X),則g(Z)在X中測度為0。

 

Lemma1表明,若generator(G)是一個神經網路,並且G的輸入(隨機高斯噪聲)的維數比產生的影象的維數低,則無論怎樣訓練,G也只能產生整個影象空間中很小的部分,有多小呢?它在影象空間中只是一個零測集。零測集是什麼概念呢,舉些例子,二維空間中的一條直線、三維空間中的一個平面。二維平面在三維空間中是沒有“厚度”的,它的體積是0。

 

我們訓練GAN時,訓練集總歸是有限的,Pr一般可以看成是低維的流形;如果Pg也是低維流形,或者它與Pr的支撐集沒有交集,則在discriminator (D)達到最優時,JSD就會被較大化。D達到最優將導致G的梯度變得異常地差,訓練將變得異常不穩定。

下面的幾個定理、引理就是在證明Pg在上述兩種情況下,最優的D是存在的。

 

Theorem2.1: 若分佈Pr和Pg的支撐集分別包含在兩個無交緊緻子集M和P中,則存在一個光滑的最優discriminator D*: X -> [0,1],它的精度是1,並且,對任意的

定理2.1是什麼意思呢?如果兩個概率分佈的支撐集沒有交集,則完美的D總是存在的,並且(在兩個分佈的支撐集的並集上)D的梯度為0,也就是說,這時候任何梯度演算法都將失效。這就是GAN訓練的時候,(在兩個概率分佈的支撐集沒有交集的情況下)當D訓練效果很好的時候,G的更新將變得很差的原因。

 

Lemma2: 設M和P是R^d的兩個非滿維度的正則子流形,再設η 和 η’ 是任意的兩個獨立連續隨機變數,定義兩個擾動流形M’ = M + η,P’ = P + η’,則

Lemma 2是為定理2.2做準備,它表明任意兩個非滿維的正則子流形都可以通過微小的擾動使得它們不是完美對齊(notperfectly align)的,即它們的交點都是橫截相交(intersect transversally)的。橫截相交和完美對齊的嚴謹定義將在下期推送中給出,在這裡形象地說明一下:

 

橫截相交(intersect transversally):對兩個子流形,在任意一個交點處,兩者的切平面能夠生成整個空間,則稱兩個子流形橫截相交。當然,如果它們沒有交集,則根據定義,它們天然就是橫截相交的。下圖給出了一個示例,在交點P處,平面的切平面是其自身,直線的切平面也是其自身,它們可以張成全空間,因此是橫截相交的,而兩個直線沒辦法張成全空間,因此不是橫截相交的;如果兩個流形是相切的,在切點處它們的切平面是相同的,也不可能張成全空間,因此也不是橫截相交的。

完美對齊(perfectly align): 如果兩個子流形有交集,並且在某個交點處,它們不是橫截相交的。

 

Pr和Pg的支撐集如果存在交集,那麼根據lemma2,我們總可以通過微小的擾動使得它們不是完美對齊的,也就是說,它們是橫截相交的。

 

Lemma3: 設M和P是R^d的兩個非完美對齊,非滿維的正則子流形,L=M∩P,若M和P無界,則L也是一個流形,且維數嚴格低於M和P的維數。若它們有界,則L是至多四個維數嚴格低於全空間的流形之並。無論哪種情形,L在M或者P中的測度均為0。

 

Lemma 3說的是,兩個正則子流形(滿足一定條件:非完美對齊,非滿維)的交集的維數要嚴格低於它們自身的維數。也就是說,它們的交集只是冰山一角,小到相對它們自身可以忽略。對於Pr和Pg的支撐集來說,根據Lemma 2,我們總可以通過微小擾動使得它們是非完美對齊的,在根據Lemma 3,Pr和Pg的交集是微不足道。

 

Theorem2.2: 設Pr和Pg分別是支撐集包含在閉流形M和P中的兩個分佈,且它們非完美對齊、非滿維。進一步地,我們假設Pr和Pg在各自的流形中分別連續,即:若集合A在M中測度為0,則Pr(A) = 0(在Pg上也有類似結果)。則存在精度為1的最優discriminator D*: X->[0,1],並且幾乎對M 或者P中的任意點x,D*在x的任意鄰域內總是光滑的,且有

定理2.1證明的是對於Pr和Pg無交的情形下,最優的discriminator是存在的。定理2.2承接Lemma 3,它證明了在Pr和Pg的支撐集有交集,且橫截相交的情況下,最優的discriminator是存在的。這兩個定理實際上把兩種可能導致D最優,且梯度消失的情形在理論上做出證明,由於梯度的消失,G的更新將得不到足夠的梯度,導致G很差。

 

Theorem2.3: 在定理2.2的條件下,有

定理2.3表明,隨著D越來越好,D的loss將越來越小,趨於0,因此Pr和Pg的JSD被較大化,達到較大值log2,這時,Pr和Pg的交叉熵達到無窮大,也就是說,即使兩個分佈之間的差異任意地小,它們之間的KL散度仍然會被較大化,趨於無窮。這是什麼意思呢?利用KL散度來衡量分佈間的相似性在這裡並不是一個好的選擇。因此,我們有必要尋求一個更好的衡量指標。

 

定理2.4 探究了generator在前面所述情況下回出現什麼問題,它從理論上給出了,若G採用original cost function(零和博弈),那麼它的梯度的上界被D與最優的D*之間的距離bound住。說人話就是,我們訓練GAN的時候,D越接近最優的D*,則G的梯度就越小,如果梯度太小了,梯度演算法不能引導G變得更好。下圖給出了G的梯度變化(固定G,訓練D),注意到隨著訓練的進行,D將變得越來越較精確,這時G的梯度強度將變得越來越小,與理論分析符合。

定理2.5研究了G的loss為the –logD cost時將會出現的問題。我們可以看到,當JSD越大時,G的梯度反而會越小,也就是說,它可能會引導兩個分佈往相異的方向,此外,上式的KL項雖對產生無意義影象會有很大的懲罰,但是對mode collapse懲罰很小,也就是說,GAN訓練時很容易落入區域性最優,產生mode collapse。KL散度不是對稱的,但JSD是對稱的,因此JSD並不能改變這種狀況。這就是我們在訓練GAN時經常出現mode collapse的原因。

 

定理2.6告訴我們,若G採用the –logD cost,在定理2.1或者2.2的條件下,當D與D*足夠接近時,G的梯度會呈現強烈震盪,這也就是說,G的更新會變得很差,可能導致GAN訓練不穩定。

 

下圖給出了定理2.6的實驗模擬的效果,在DCGAN尚未收斂時,固定G,訓練D將導致G的梯度產生強烈震盪。當DCGAN收斂時,這種震盪得到有效的抑制。

既然general GAN採用的loss不是一種好的選擇,有什麼loss能夠有效避免這種情形嗎?

 

一個臨時解決方案

一個可行的方案是打破定理的條件,給D的輸入新增噪聲。後續的幾個定理對此作了回答。

定理3.1和推論3.1表明,ε的分佈會影響我們對距離的選擇。

定理3.2證明了G的梯度可以分為兩項,第一項表明,G會被引導向真實資料分佈移動,第二項表明,G會被引導向概率很高的生成樣本遠離。作者指出,上述的梯度格式具有一個很嚴重的問題,那就是由於g(Z)是零測集,D在優化時將忽略該集合;然而G卻只在該集合上進行優化。進一步地,這將導致D極度容易受到生成樣本的影響,產生沒有意義的樣本。

對D的輸入新增噪聲,在訓練的過程中將引導噪聲樣本向真實資料流形的方向移動,可以看成是引導樣本的一個小鄰域向真實資料移動。這可以解決D極度容易受到生成樣本的影響的問題。

Wasserstein距離

定義3.1: X上的兩個分佈P、Q的Wasserstein度量W(P,Q)定義為

其中,Г是X×X上所有具有邊界分佈P和Q的聯合分佈集。

 

Wasserstein距離通常也稱為轉移度量或者EM距離(地動距離),它表示從一個分佈轉移成另一個分佈所需的最小代價。下圖給出了一個離散分佈下的例子,將f1(x)遷移成f2(x)最小代價即是移動f1(x)在較大值處的2個單位的概率到最小值處,這樣就得到了分佈f2(x)。更復雜的離散情形需要通過求解規劃問題。

定理3.3告訴我們一個有趣的事實,上式右邊兩項均能被控制。第一項可以通過逐步減小噪聲來逐步減小;第二項可以通過訓練GAN(給D的輸入新增噪聲)來最小化。

 

作者指出,這種通過給D的輸入新增噪聲的解決方案具有一大好處,那就是我們不需要再擔心訓練過程。由於引入了噪聲,我們可以訓練D直到最優而不會遇到G的梯度消失或者訓練不穩定的問題,此時G的梯度可以通過推論3.2給出。

 

總而言之,WGAN的前傳從理論上研究了GAN訓練過程中經常出現的兩大問題:G的梯度消失、訓練不穩定。並且提出了利用地動距離來衡量Pr和Pg的相似性、對D的輸入引入噪聲來解決GAN的兩大問題,作者證明了地動距離具有上界,並且上界可以通過有效的措施逐步減小。

 

這可以說是一個臨時性的解決方案,作者甚至沒有給出實驗進行驗證。在WGAN[2]這篇文章中,作者提出了更完善的解決方案,並且做了實驗進行驗證。下面我們就來看一下這篇文章。

 

Wasserstein GAN

 

常見距離

Martin Arjovsky在WGAN論文進一步論述了為什麼選擇Wasserstein距離(地動距離)。

設X是一個緊緻度量空間,我們這裡討論的影象空間([0,1]^d)就是緊緻度量空間。用Σ表示X上的所有博雷爾集,用Prob(X)表示定義在X上的概率度量空間。給定Prob(X)上的兩個分佈Pr, Pg,我們可以定義它們的距離/散度(請注意:散度不是距離,它不是對稱的。距離和散度都可以用於衡量兩個分佈的相似程度):

表示以Pr, Pg為邊緣分佈的所有聯合分佈組成的集合。

我們用一個簡單的例子來看一下這四種距離/散度是怎麼計算的。

 

考慮下圖的兩個均勻分佈:

二維平面上,P1是沿著y軸的[0,1]區間上的均勻分佈,P2是沿著x=θ,在y軸的[0,1]區間上的均勻分佈。簡而言之,你可以把P1和P2看成是兩條平行的線段。容易計算

當θ->0時,W->0,然而TV距離、KL散度、JS散度都不收斂。也就是說,地動距離對某些情況下要更合理一些。更嚴謹的結論由下面的定理給出。

PS: 為了統一編號,後續的定理編號與原文[2]的編號不一樣,兩種編號相差3...

 

 

3. 上述兩個結論對JS散度和KL散度均不成立。

 

定理4表明,地動距離與JS散度、KL散度相比,具有更好的性質。

定理5表明,如果分佈的支撐集在低維流形上,KL散度、JS散度和TV距離並不是好的loss,而地動(EM)距離則很合適。這啟發我們可以用地動距離來設計loss以替換原來GAN採用的KL散度。

WGAN

採用Wasserstein距離作為loss的GAN稱為WassersteinGAN,一般簡寫為WGAN。直接考慮Wasserstein距離需要算inf,計算是很困難的。考慮它的Kantorovich-Rubinstein對偶形式

可以看到,如果把GAN的目標函式的log去掉,則兩者只相差一個常數,也就是說,WGAN在訓練的時候與GAN幾乎一樣,除了loss計算的時候不取對數!Loss function中的對數函式導致了GAN訓練的不穩定!

定理6證明了若D和G的學習能力足夠強的話(因此目標函式能夠被較大化),WGAN是有解的。WGAN的演算法流程如下:

WGAN實驗

作者發現,如果WGAN訓練採用SGD或者RMSProp演算法,則收斂效果很好。一般不採用基於momentum的演算法,如Adam演算法,實現觀察發現這類優化演算法會導致訓練變得不穩定。而我們知道,DCGAN採用Adam演算法進行優化效果會比較好。這是WGAN與GAN訓練方法的差別。

 

此外,WGAN當前的loss(Wasserstein距離)能夠用於指示訓練的效果,即G產生的影象質量,Wasserstein距離越小,G產生的影象質量就越高。先前的GAN由於訓練不穩定,我們很難通過loss去判斷G產生的質量(先前的GAN的loss大小並不能表明產生影象質量的高低)。這個發現對於訓練GAN有很大的幫助。

 

此外,WGAN如果採用DCGAN架構去訓練,產生的影象質量效果與DCGAN沒有明顯差異;並且,即使generator採用MLP(多層感知機),仍然能夠產生質量不錯的影象。實驗結果如下圖所示(圖中的曲線不會跟GAN一樣產生強烈震盪了!)。

此外,作者指出,WGAN的實驗中並沒有發現mode collapse!

WGAN程式碼及材料

 

Reddit討論區傳送門:

https://www.reddit.com/r/MachineLearning/comments/5qxoaz/r_170107875_wasserstein_gan/?from=groupmessage

 

推薦一篇用更通俗易懂的語言介紹WGAN的文章:

https://zhuanlan.zhihu.com/p/25071913

 

WGAN原始碼,作者提供,Torch版本:

https://github.com/martinarjovsky/WassersteinGAN

Tensorflow版本:https://github.com/Zardinality/WGAN-tensorflow

Keras版本:

https://github.com/tdeboissiere/DeepLearningImplementations/tree/master/WassersteinGAN

 

參考文獻

Arjovsky, M., & Bottou, L.eon. (2017). Towards Principled Methods for Training Generative AdversarialNetworks.

Arjovsky, M., Soumith, C.,& Bottou, L. eon. (n.d.). Wasserstein GAN.