1. 程式人生 > >GAN完整理論推導與實現,Perfect!

GAN完整理論推導與實現,Perfect!

本文是機器之心第二個 GitHub 實現專案,上一個 GitHub 實現專案為從頭開始構建卷積神經網路。在本文中,我們將從原論文出發,藉助 Goodfellow 在 NIPS 2016 的演講和臺大李弘毅的解釋,完成原 GAN 的推導、證明與實現。

本文主要分四部分,第一部分描述 GAN 的直觀概念,第二部分描述概念與優化的形式化表達,第三部分將對 GAN 進行詳細的理論推導與分析,最後我們將實現前面的理論分析。

  • GitHub專案地址:https://github.com/jiqizhixin/ML-Tutorial-Experiment

本文更注重理論與推導,更多生成對抗網路的概念與應用請參閱:

生成對抗網路基本概念

要理解生成對抗模型(GAN),首先要了解生成對抗模型可以拆分為兩個模組:一個是判別模型,另一個是生成模型。簡單來說就是:兩個人比賽,看是 A 的矛厲害,還是 B 的盾厲害。比如,我們有一些真實資料,同時也有一把隨機生成的假資料。A 拼命地把隨手拿過來的假資料模仿成真實資料,並揉進真實資料裡。B 則拼命地想把真實資料和假資料區分開。

這裡,A 就是一個生成模型,類似於造假幣的,一個勁地學習如何騙過 B。而 B 則是一個判別模型,類似於稽查警察,一個勁地學習如何分辨出 A 的造假技巧。

如此這般,隨著 B 的鑑別技巧越來越厲害,A 的造假技巧也是越來越純熟,而一個一流的假幣制造者就是我們所需要的。雖然 GAN 背後的思想十分直觀與樸素,但我們需要更進一步瞭解該理論背後的證明與推導。

總的來說,Goodfellow 等人提出來的 GAN 是通過對抗過程估計生成模型的新框架。在這種框架下,我們需要同時訓練兩個模型,即一個能捕獲資料分佈的生成模型 G 和一個能估計資料來源於真實樣本概率的判別模型 D。生成器 G 的訓練過程是最大化判別器犯錯誤的概率,即判別器誤以為資料是真實樣本而不是生成器生成的假樣本。因此,這一框架就對應於兩個參與者的極小極大博弈(minimax game)。在所有可能的函式 G 和 D 中,我們可以求出唯一均衡解,即 G 可以生成與訓練樣本相同的分佈,而 D 判斷的概率處處為 1/2,這一過程的推導與證明將在後文詳細解釋。

當模型都為多層感知機時,對抗性建模框架可以最直接地應用。為了學習到生成器在資料 x 上的分佈 P_g,我們先定義一個先驗的輸入噪聲變數 P_z(z),然後根據 G(z;θ_g) 將其對映到資料空間中,其中 G 為多層感知機所表徵的可微函式。我們同樣需要定義第二個多層感知機 D(s;θ_d),它的輸出為單個標量。D(x) 表示 x 來源於真實資料而不是 P_g 的概率。我們訓練 D 以最大化正確分配真實樣本和生成樣本的概率,因此我們就可以通過最小化 log(1-D(G(z))) 而同時訓練 G。也就是說判別器 D 和生成器對價值函式 V(G,D) 進行了極小極大化博弈:

我們後一部分會對對抗網路進行理論上的分析,該理論分析本質上可以表明如果 G 和 D 的模型複雜度足夠(即在非引數限制下),那麼對抗網路就能生成資料分佈。此外,Goodfellow 等人在論文中使用如下案例為我們簡要介紹了基本概念。

如上圖所示,生成對抗網路會訓練並更新判別分佈(即 D,藍色的虛線),更新判別器後就能將資料真實分佈(黑點組成的線)從生成分佈 P_g(G)(綠色實線)中判別出來。下方的水平線代表取樣域 Z,其中等距線表示 Z 中的樣本為均勻分佈,上方的水平線代表真實資料 X 中的一部分。向上的箭頭表示對映 x=G(z) 如何對噪聲樣本(均勻取樣)施加一個不均勻的分佈 P_g。(a)考慮在收斂點附近的對抗訓練:P_g 和 P_data 已經十分相似,D 是一個區域性準確的分類器。(b)在演算法內部迴圈中訓練 D 以從資料中判別出真實樣本,該迴圈最終會收斂到 D(x)=p_data(x)/(p_data(x)+p_g(x))。(c)隨後固定判別器並訓練生成器,在更新 G 之後,D 的梯度會引導 G(z)流向更可能被 D 分類為真實資料的方向。(d)經過若干次訓練後,如果 G 和 D 有足夠的複雜度,那麼它們就會到達一個均衡點。這個時候 p_g=p_data,即生成器的概率密度函式等於真實資料的概率密度函式,也即生成的資料和真實資料是一樣的。在均衡點上 D 和 G 都不能得到進一步提升,並且判別器無法判斷資料到底是來自真實樣本還是偽造的資料,即 D(x)= 1/2。

上面比較精簡地介紹了生成對抗網路的基本概念,下一節將會把這些概念形式化,並描述優化的大致過程。

概念與過程的形式化

理論完美的生成器

該演算法的目標是令生成器生成與真實資料幾乎沒有區別的樣本,即一個造假一流的 A,就是我們想要的生成模型。數學上,即將隨機變數生成為某一種概率分佈,也可以說概率密度函式為相等的:P_G(x)=P_data(x)。這正是數學上證明生成器高效性的策略:即定義一個最優化問題,其中最優生成器 G 滿足 P_G(x)=P_data(x)。如果我們知道求解的 G 最後會滿足該關係,那麼我們就可以合理地期望神經網路通過典型的 SGD 訓練就能得到最優的 G。

最優化問題

正如最開始我們瞭解的警察與造假者案例,定義最優化問題的方法就可以由以下兩部分組成。首先我們需要定義一個判別器 D 以判別樣本是不是從 P_data(x) 分佈中取出來的,因此有:

其中 E 指代取期望。這一項是根據「正類」(即辨別出 x 屬於真實資料 data)的對數損失函式而構建的。最大化這一項相當於令判別器 D 在 x 服從於 data 的概率密度時能準確地預測 D(x)=1,即:

另外一項是企圖欺騙判別器的生成器 G。該項根據「負類」的對數損失函式而構建,即:

因為 x<1 的對數為負,那麼如果最大化該項的值,則需要令均值 D(G(z))≈0,因此 G 並沒有欺騙 D。為了結合這兩個概念,判別器的目標為最大化:

給定生成器 G,其代表了判別器 D 正確地識別了真實和偽造資料點。給定一個生成器 G,上式所得出來的最優判別器可以表示為 (下文用 D_G*表示)。定義價值函式為:

然後我們可以將最優化問題表述為:

現在 G 的目標已經相反了,當 D=D_G*時,最優的 G 為最小化前面的等式。在論文中,作者更喜歡求解最優化價值函的 G 和 D 以求解極小極大博弈:

對於 D 而言要儘量使公式最大化(識別能力強),而對於 G 又想使之最小(生成的資料接近實際資料)。整個訓練是一個迭代過程。其實極小極大化博弈可以分開理解,即在給定 G 的情況下先最大化 V(D,G) 而取 D,然後固定 D,並最小化 V(D,G) 而得到 G。其中,給定 G,最大化 V(D,G) 評估了 P_G 和 P_data 之間的差異或距離。

最後,我們可以將最優化問題表達為:

上文給出了 GAN 概念和優化過程的形式化表達。通過這些表達,我們可以理解整個生成對抗網路的基本過程與優化方法。當然,有了這些概念我們完全可以直接在 GitHub 上找一段 GAN 程式碼稍加修改並很好地執行它。但如果我們希望更加透徹地理解 GAN,更加全面地理解實現程式碼,那麼我們還需要知道很多推導過程。比如什麼時候 D 能令價值函式 V(D,G) 取最大值、G 能令 V(D,G) 取最小值,而 D 和 G 該用什麼樣的神經網路(或函式),它們的損失函式又需要用什麼等等。總之,還有很多理論細節與推導過程需要我們進一步挖掘。

理論推導

在原 GAN 論文中,度量生成分佈與真實分佈之間差異或距離的方法是 JS 散度,而 JS 散度是我們在推導訓練過程中使用 KL 散度所構建出來的。所以這一部分將從理論基礎出發再進一步推導最優判別器和生成器所需要滿足的條件,最後我們將利用推導結果在數學上重述訓練過程。這一部分為我們下一部分理解具體實現提供了強大的理論支援。

KL 散度

在資訊理論中,我們可以使用夏農熵(Shannon entropy)來對整個概率分佈中的不確定性總量進行量化:

如果我們對於同一個隨機變數 x 有兩個單獨的概率分佈 P(x) 和 Q(x),我們可以使用 KL 散度(Kullback-Leibler divergence)來衡量這兩個分佈的差異:

在離散型變數的情況下,KL 散度衡量的是,當我們使用一種被設計成能夠使得概率分佈 Q 產生的訊息的長度最小的編碼,傳送包含由概率分佈 P 產生的符號的訊息時,所需要的額外資訊量。

KL 散度有很多有用的性質,最重要的是它是非負的。KL 散度為 0 當且僅當 P 和 Q 在離散型變數的情況下是相同的分佈,或者在連續型變數的情況下是 『幾乎處處』 相同的。因為 KL 散度是非負的並且衡量的是兩個分佈之間的差異,它經常 被用作分佈之間的某種距離。然而,它並不是真的距離因為它不是對稱的:對於某些 P 和 Q,D_KL(P||Q) 不等於 D_KL(Q||P)。這種非對稱性意味著選擇 D_KL(P||Q) 還是 D_KL(Q||P) 影響很大。

在李弘毅的講解中,KL 散度可以從極大似然估計中推導而出。若給定一個樣本資料的分佈 P_data(x) 和生成的資料分佈 P_G(x;θ),那麼 GAN 希望能找到一組引數θ使分佈 P_g(x;θ) 和 P_data(x) 之間的距離最短,也就是找到一組生成器引數而使得生成器能生成十分逼真的圖片。

現在我們可以從訓練集抽取一組真實圖片來訓練 P_G(x;θ) 分佈中的引數 θ 使其能逼近於真實分佈。因此,現在從 P_data(x) 中抽取 m 個真實樣本 {x^1,x^2,…,x^