1. 程式人生 > >【TensorFlow-windows】學習筆記七——生成對抗網路

【TensorFlow-windows】學習筆記七——生成對抗網路

前言

既然學習了變分自編碼(VAE),那也必須來一波生成對抗網路(GAN)。

國際慣例,參考網址:

理論

粗略點的講法就說:一個生成器G,一個判別器D,前者用來將噪聲輸入轉換成圖片,後者判別當前輸入圖片是真實的還是生成的。

為了更詳細地瞭解GAN,還是對論文進行簡要的組織、理解吧。有興趣直接看原始論文,這裡進行部分關鍵內容的摘抄。

任意的GD函式空間都存在特定解,G要能表示訓練集的分佈,而D一定是等於12,也就是說判別器無法分辨當前輸入是真的還是假的,這樣就達到了魚目混珠的效果。在GAN中,使用多層感知器構建GD,整個模型可以使用反向傳播演算法學習。

論文裡面有一句很好的話解釋了GAN的動機:目前深度學習在判別模型的設計中取得了重大成功,但是在生成模型卻鮮有成效,主要原因在於在極大似然估計和相關策略中有很多難以解決的概率計算難題(想想前一篇部落格的變分自編碼的理論,闊怕),而且丟失了生成背景下的分段線性單元的優勢,因此作者就提出了新的生成模型估計方法,避開這些難題,也就是傳說中的GAN。它的訓練完全不需要什麼鬼似然估計,只需要使用目前炒雞成功的反傳和dropout演算法。

為了讓生成器學到資料分佈pg,需要定義一個先驗的噪聲輸入pz(z),然後使用G(z;θg)將其對映到資料空間,這裡的G是具有引數θ

g的多層感知器。然後定義另一個多層感知器D(x;θd),輸出一個標量。D(x)代表的是x來自於真實樣本而非生成的樣本pg的概率,我們訓練D去最大化將正確標籤同時賦予訓練集和G生成的樣本的概率,也就是D把真的和假的圖片都當成真的了。同時要去訓練G去最小化log(1D(G(z))),是為了讓生成的圖片被賦予正樣本標籤的概率大點,損失函式就是:

minGmaxDV(D,G)=Expdata(x)[logD(x
)]+Ezpz(z)[log(1D(G(z)))]

在優化D的時候,在訓練的內迴圈中是無法完成的,計算上不允許,並且在有限資料集上會導致過擬合,因此可以以k:1 的訓練次數比例分別優化DG。這能夠讓D保持在最優解附近,只要G變化比較緩慢。

而且在實際中,上式可能無法提供足夠的梯度讓G很好地學習,在訓練早期,當G很差的時候,D能夠以很高的概率將其歸為負樣本,因為生成的資料與訓練資料差距很大,這樣log(1D(G(z)))就飽和了,與其說最小化log(1D(G(z)))不如去最大化log(D(G(z))),這個目標函式對GD的收斂目標不變,但是能早期學習具有更強的梯度。

訓練演算法:

外層一個大迴圈就不說了,對所有的批資料迭代,內層有一個小迴圈,控制上面說的判別器D與生成器G的訓練比例為k:1的:

  • 以下步驟執行k次:

    • 從噪聲先驗pg(z)中取樣m個噪聲樣本