1. 程式人生 > >GAN原理解析,公式推導與python實現

GAN原理解析,公式推導與python實現

1-生成模型

1-1 生成模型與判別模型

生成式對抗網路,顧名思義就是生成模型嘛!那什麼是生成模型呢?與判別模型有什麼區別呢?

先來理解一下判別模型。
學過機器學習的人都知道,入門級的演算法邏輯迴歸,最後的預測,是通過sigmoid函式:
這裡寫圖片描述
生成一個0-1之間的數值,然後用某一閥值來做分類,我們稱之為判別模型:由資料直接學習,通過決策函式Y=f(X)或者概率模型P(Y|X)預測,判斷類別的模型。邏輯迴歸中的sigmoid函式就是判別函式。

而生成模型,則先學習出一個聯合概率密度分佈P(X,Y),最後分類,預測的時候,使用條件概率公式

P(Y|X)=P(X,Y)P(X)來預測當前樣本屬於每一類的概率。
生成模型的核心,就是先求兩個比較好求的概率,然後通過貝葉斯概率公式這樣的關係來進行分類。

簡單的看,
這裡寫圖片描述

1-2 為什麼學習生成模型?

我們可以總結一下它的優點:

  1. 能生成高維的資料或者複雜的概率分佈,且高維資料分佈在數學和工業界都扮演著重要的作用
  2. 還可以為強化學習做一定準備
  3. 對於缺失資料較多的場景,可以用來生成更多的樣例資料,是當前用來解決資訊缺失的最好方式。

舉一些例子:
Next Video Frame Prediction
這裡寫圖片描述
大概意思是預測下一幀會是什麼?比如第一個頭像是當前幀的狀態,然後給出一個MSE(均方差矩陣),預測下一幀會出現什麼,如圖就是頭轉了大約15度或者30度左右的樣子。

很明顯,這個應用是很強大的,比如某些打碼的片子,甚者打碼的圖片,三級片之類的,或者有損失的古物,古畫,都有可能通過生成模型,來生成新的。

Single Image Super-Resolution
這裡寫圖片描述
或者處理比較模糊的圖片,因為畫素太低,導致人看了不怎麼清楚,可以通過GAN來生成更高畫素的圖片,看的更清晰。上圖左為原始圖片,第二張為 使用bicubic method的插值法得出的,第三張是使用ResNet,第四張是使用GAN來生成的。

Image to Image Translation
這裡寫圖片描述
根據你畫的樣子,給你生成一個你可能想都沒想過的樣子,或者根據地圖,生成場景之類的。

總之,GAN的特點之一,就是生成,生成一些你可能想都沒想過的東西。

1-3 生成模型原理—似然原理

這裡寫圖片描述
所謂的生成模型,其實就是基於最大似然估計的,而最大似然估計就是用的似然原理。

什麼是似然原理呢?我們舉個例子,比如你要估計一個學校的數學成績是多少,肯定不會直接找全校的學生,然後再把他們的數學成績放在一起計算吧。因為這樣做的代價太大了,現實情況根本不允許。

那我們該怎麼辦呢?
那我們可以隨機取樣嘛,用樣本來估計總體。什麼意思呢?我們知道一個群體的某一形狀假如服從正態分佈,如圖所示,那麼這個分佈的形狀由兩方面決定,分別是均值μ和方差σ2,只要這兩個定了,那麼影象的概率密度分佈也就定了。

那我們能否用樣本(抽到的學生)的均值和方差,估計總體(整個學校的學生)的均值和方差,這樣不就得到了總體的概率分佈了嗎?其實這就是生成模型的原理。
想要詳細瞭解它的公式推導,可以看一下這篇的最後一個部分:概率統計學習基礎

2-生成式對抗網路

2-1 生成式對抗網路工作原理

前面我們說的都是前期知識準備,瞭解生成模型,接下來看一下真正的生成式對抗網路,它的工作流如下:
這裡寫圖片描述
GAN是一種structured probabilistic model,具體介紹在deep learning這本書的第16章有。
顧名思義,生成-對抗,其核心也是兩個,一個是生成,用圖中的G(z)表示,z是隱變數;一個是對抗,用圖中的D(x)表示,x是觀測變數。圖結構表示如下:

GAN是一個有向圖模型,它的每一個隱變數都在影響觀測變數。

這裡寫圖片描述
我們希望達到的效果,就如上圖所示,生成式對抗網路會訓練並更新判別分佈(D,圖中藍色虛線部分),希望能將真實的分佈(pdata(x),)和生成分佈(pmodel(G(z)),)區分開來。z表示屬於某一分佈的噪聲資料,經過生成器之後,變成x=G(z),得到一個生成分佈pmodel(G(z))。而上方的x水平線則代表真實的分佈X中的一部分。

主要目標就兩個:

  • 判別器D(x)獨自訓練自己,希望能分辨出真實的資料分佈和生成器給的資料分佈
  • 生成器G(z)也訓練自己,希望以假亂真,讓判別器判別不出到底哪個是真,哪個是假

G(z)D(x)它們的區別主要在於輸入的資料和引數:θ(G),θ(D).

我們輸入noise樣本給G(z),讓它生成一張圖片,因為是noise樣本,生成的圖片效果應該不會太好,然後交給D(x)D(x)說你生成的圖是錯的,或者生成的圖是正確的,但D(x)卻認為你G(z)生成的是錯的,也就是說D(x)出現了判斷錯誤,這就會產生誤差,自己就可以用梯度下降法來優化自己;

如果D(x)判斷你生成的不對,那G(z)就知道自己哪裡不對了,也可以用梯度下降法來優化自己,以生成更好的圖片。

  • 舉個例子,這個過程就像是一個畫畫的老師在教,或者說監督學生畫畫一樣,老師就是對抗的部分,學生就是生成部分。老師手上有一副真的蒙娜麗莎,學生手上有一副蒙娜麗莎的贗品,學生在模仿贗品畫蒙娜麗莎,然後給老師看,老實說你這裡畫得不像,重來,然後不斷改進,直到學生畫得蒙娜麗莎老師分辨不出來到底是真是假是,訓練停止。

判別部分(Generator Network)

輸入z,通過對應的引數,能生成x,x屬於生成的概率分佈:Pmodel,用公式表示如下:

x=G(z;θ(G))如果我們想要生成x和生成分佈Pmodel相吻合,就得要求輸入的資料z的維度至少是和x的維度相差無幾的。

訓練過程
GAN的訓練過程也是使用SGD,或者其他的優化演算法來做優化,一般使用minibatch,值得注意的是,我們可以讓其中一個先跑一下,一般讓生成部分先跑, 使得生成的效率更高一點。

2-2 判別器的損失函式

一個演算法模型最重要的,莫過於損失函式和優化的方法了,對於判別部分來說,所有的GAN變體都是使用相同的損失函式J(D)定義,他們之間不同之處在於生成部分選擇的損失函式J(G)不同。

首先我們必須先定義樣本是從真實分佈pdata(x)來的,還是從生成分佈pmodel(x)中來的。因此有:

ExpdatalogD(x)這裡的E表示期望。這一項是根據正類(即能判別出x屬於真實分佈)的對數損失函式構建的,最大化這一項意味著令判別器D在x服從於data的概率密度時能準確預測D(x)= 1,即:D(x)=1whenxpdata(x)

另一項是根據負類的對數損失函式構建的:

Ezpmodellog(1D(G(z)))D(x)=0whenxp