1. 程式人生 > >白話生成對抗網路 GAN,50 行程式碼玩轉 GAN 模型!【附原始碼】

白話生成對抗網路 GAN,50 行程式碼玩轉 GAN 模型!【附原始碼】

今天,紅色石頭帶大家一起來了解一下如今非常火熱的深度學習模型:生成對抗網路(Generate Adversarial Network,GAN)。GAN 非常有趣,我就以最直白的語言來講解它,最後實現一個簡單的 GAN 程式來幫助大家加深理解。

1. 什麼是 GAN?

好了,GAN 如此強大,那它到底是一個什麼樣的模型結構呢?我們之前學習過的機器學習或者神經網路模型主要能做兩件事:預測和分類,這也是我們所熟知的。那麼是否可以讓機器模型自動來生成一張圖片、一段語音?而且可以通過調整不同模型輸入向量來獲得特定的圖片和聲音。例如,可以調整輸入引數,獲得一張紅頭髮、藍眼睛的人臉,可以調整輸入引數,得到女性的聲音片段,等等。也就是說,這樣的機器模型能夠根據需求,自動生成我們想要的東西。因此,GAN 應運而生!

GAN,即生成對抗網路,主要包含兩個模組:生成器(Generative Model)和判別器(Discriminative Model)。生成模型和判別模型之間互相博弈、學習產生相當好的輸出。以圖片為例,生成器的主要任務是學習真實圖片集,從而使得自己生成的圖片更接近於真實圖片,以“騙過”判別器。而判別器的主要任務是找出出生成器生成的圖片,區分其與真實圖片的不同,進行真假判別。在整個迭代過程中,生成器不斷努力讓生成的圖片越來越像真的,而判別器不斷努力識別出圖片的真假。這類似生成器與判別器之間的博弈,隨著反覆迭代,最終二者達到了平衡:生成器生成的圖片非常接近於真實圖片,而判別器已經很難識別出真假圖片的不同了。其表現是對於真假圖片,判別器的概率輸出都接近 0.5。

對 GAN 的概念還是有點不清楚?沒關係,舉個生動的例子來說明。

最近,紅色石頭想學習繪畫,是因為看到梵大師的畫作,也想畫出類似的作品。梵大師的畫作像這樣:

這裡寫圖片描述

說畫就畫,紅色石頭找來一個研究梵大師作品很多年的王教授來指導我。王教授經驗豐富,眼光犀利,市面上模仿梵大師的畫作都難逃他的法眼。王教授跟我說了一句話:什麼時候你的畫這幅畫能騙過我,你就算是成功了。

紅色石頭很激動,立馬給王教授畫了這幅畫:

這裡寫圖片描述

王教授輕輕掃了一眼,滿臉黑線,氣的直哆嗦,“0 分!這也叫畫?差得太多了!” 聽了王教授的話,紅色石頭自我反省,確實畫的不咋地,連眼睛、鼻子都沒有。於是,又 重新畫了一幅:

這裡寫圖片描述

王教授一看,不到 2 秒鐘,就丟下四個字:1 分!重畫!紅色石頭一想,還是不行,畫得太差了,就回去好好研究梵大師的畫作風格,不斷改進,重新創作,直到有一天,紅色石頭拿著新的畫作給王教授看:

這裡寫圖片描述

王教授看了一看,說有點像了。我得仔細看看。最後,還是跟我說,不行不行,細節太差!繼續重新畫吧。唉,王教授越來越嚴格了!紅色石頭嘆了口氣回去繼續研究,最後將自我很滿意的一幅畫交給了王教授鑑賞:

這裡寫圖片描述

這下,王教授戴著眼鏡,仔細品析,許久之後,王教授拍著我的肩膀說,畫得很好,我已經識別不了真假了。哈哈,得到了王教授的誇獎和肯定,心裡美滋滋,終於可以創作出梵大師樣的繪畫作品了。下一步考慮轉行去。

好了,例子說完了(接受大家對我繪畫天賦的吐槽)。這個例子,其實就是一個 GAN 訓練的過程。紅色石頭就是生成器,目的就是要輸出一幅畫能夠騙過王教授,讓王教授真假難辨!王教授就是判別器,目的就是要識別出紅色石頭的畫作,判斷其為假的!整個過程就是“生成 — 對抗”的博弈過程,最終,紅色石頭(生成器)輸出一幅“以假亂真”的畫作,連王教授(判別器)都難以區分了。

這就是 GAN,懂了吧。

2. GAN 模型基本結構

在認識 GAN 模型之前,我們先來看一看 Yann LeCun 對未來深度學習重大突破技術點的個人看法:

The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This is an idea that was originally proposed by Ian Goodfellow when he was a student with Yoshua Bengio at the University of Montreal (he since moved to Google Brain and recently to OpenAI).

This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.

Yann LeCun 認為 GAN 很可能會給深度學習模型帶來新的重大突破,是20年來機器學習領域最酷的想法。這幾年 GAN 發展勢頭非常強勁。下面這張圖是近幾年 ICASSP 會議上所有提交的論文中包含關鍵詞 “generative”、“adversarial” 和 “reinforcement” 的論文數量統計。

這裡寫圖片描述

資料表明,2018 年,包含關鍵詞 “generative” 和 “adversarial” 的論文數量發生井噴式增長。不難預見, 未來幾年關於 GAN 的論文會更多。

下面來介紹一下 GAN 的基本結構,我們已經知道了 GAN 由生成器和判別器組成,各用 G 和 D 表示。以生成圖片應用為例,其模型結構如下所示:

這裡寫圖片描述

GAN 基本模型由 輸入 Vector、G 網路、D 網路組成。其中,G 和 D 一般都是由神經網路組成。G 的輸出是一幅圖片,只不過是以全連線形式。G 的輸出是 D 的輸入,D 的輸入還包含真實樣本集。這樣, D 對真實樣本儘量輸出 score 高一些,對 G 產生的樣本儘量輸出 score 低一些。每次迴圈迭代,G 網路不斷優化網路引數,使 D 無法區分真假;而 D 網路也在不斷優化網路引數,提高辨識度,讓真假樣本的 score 有差距。

最終,經過多次訓練迭代,GAN 模型建立:

這裡寫圖片描述

最終的 GAN 模型中,G 生成的樣本以假亂真,D 輸出的 score 接近 0.5,即表示真假樣本難以區分,訓練成功。

這裡,重點要講解一下輸入 vector。輸入向量是用來做什麼的呢?其實,輸入 vector 中的每一維度都可以代表輸出圖片的某個特徵。比如說,輸入 vector 的第一個維度數值大小可以調節生成圖片的頭髮顏色,數值大一些是紅色,數值小一些是黑色;輸入 vector 的第二個維度數值大小可以調節生成圖片的膚色;輸入 vector 的第三個維度數值大小可以調節生成圖片的表情情緒,等等。

這裡寫圖片描述

GAN 的強大之處也正是在於此,通過調節輸入 vector,就可以生成具有不同特徵的圖片。而這些生成的圖片不是真實樣本集裡有的,而是即合理而又沒有見過的圖片。是不是很有意思呢?下面這張圖反映的是不同的 vector 生成不同的圖片。

這裡寫圖片描述

說完了 GAN 的模型之後,我們再來簡單看一下 GAN 的演算法原理。既然有兩個模組:G 和 D,每個模組都有相應的網路引數。

先來看 D 模組,它的目標是讓真實樣本 score 越大越好,讓 G 產生的樣本 score 越小越好。那麼可以得到 D 的損失函式為:

這裡寫圖片描述

其中,x 是真實樣本,G(z) 是 G 生成樣本。我們希望 D(x) 越大越好,D(G(z)) 越小越好,也就是希望 -D(x) 越小越好,-log(1-D(G(z))) 越小越好。從損失函式的角度來說,能夠得到上式。

再來看 G 模組,它的目標就是希望其生成的模型能夠在 D 中得到越高的分數越好。那麼可以得到 G 的損失函式為:

這裡寫圖片描述

知道了損失函式之後,接下來就可以使用各種優化演算法來訓練模型了。

3. 動手寫個 GAN 模型

接下來,我將使用 PyTorch 實現一個簡單的 GAN 模型。仍然以繪畫創作為例,假設我們要創造如下“名畫”(以正弦圖形為例):

這裡寫圖片描述

生成該“藝術畫作”的程式碼如下:

def artist_works():    # painting from the famous artist (real target)
   r = 0.02 * np.random.randn(1, ART_COMPONENTS)
   paintings = np.sin(PAINT_POINTS * np.pi) + r
   paintings = torch.from_numpy(paintings).float()
   return paintings

然後,分別定義 G 網路和 D 網路模型:

G = nn.Sequential(                  # Generator
   nn.Linear(N_IDEAS, 128),        # random ideas (could from normal distribution)
   nn.ReLU(),
   nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas
)

D = nn.Sequential(                  # Discriminator
   nn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like G
   nn.ReLU(),
   nn.Linear(128, 1),
   nn.Sigmoid(),                   # tell the probability that the art work is made by artist
)

我們設定 Adam 演算法進行優化:

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

最後,構建 GAN 迭代訓練過程:

plt.ion()    # something about continuous plotting

D_loss_history = []
G_loss_history = []
for step in range(10000):
   artist_paintings = artist_works()          # real painting from artist
   G_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideas
   G_paintings = G(G_ideas)                   # fake painting from G (random ideas)

   prob_artist0 = D(artist_paintings)         # D try to increase this prob
   prob_artist1 = D(G_paintings)              # D try to reduce this prob

   D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
   G_loss = torch.mean(torch.log(1. - prob_artist1))

   D_loss_history.append(D_loss)
   G_loss_history.append(G_loss)

   opt_D.zero_grad()
   D_loss.backward(retain_graph=True)    # reusing computational graph
   opt_D.step()

   opt_G.zero_grad()
   G_loss.backward()
   opt_G.step()

   if step % 50 == 0:  # plotting
       plt.cla()
       plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
       plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')
       plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 8})
       plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 8})
       plt.ylim((-1, 1));plt.legend(loc='lower right', fontsize=10);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()

我採用了動態繪圖的方式,便於時刻觀察 GAN 模型訓練情況。

迭代次數為 1 時:

這裡寫圖片描述

迭代次數為 200 時:

這裡寫圖片描述

迭代次數為 1000 時:

這裡寫圖片描述

迭代次數為 10000 時:

這裡寫圖片描述

完美!經過 10000 次迭代訓練之後,生成的曲線已經與標準曲線非常接近了。D 的 score 也如預期接近 0.5。

完整程式碼有 .py 和 .ipynb 兩種版本,我已經放在了 GitHub 上,需要的請點選下面的連結獲取。

這裡寫圖片描述