1. 程式人生 > >生成對抗式網路GAN 的 loss

生成對抗式網路GAN 的 loss

GAN同時要訓練一個生成網路(Generator)和一個判別網路(Discriminator),前者輸入一個noise變數 z ,輸出一個偽圖片資料 G(z;θg),後者輸入一個圖片(real image)以及偽圖片(fake image)資料 x ,輸出一個表示該輸入是自然圖片或者偽造圖片的二分類置信度 D(x;θd),理想情況下,判別器D 需要儘可能準確的判斷輸入資料到底是一個真實的圖片還是某種偽造的圖片,而生成器G又需要盡最大可能去欺騙D,讓D把自己產生的偽造圖片全部判斷成真實的圖片。
根據上述訓練過程的描述,我們可以定義一個損失函式:

Loss=1mmi=1[logD(xi)+l

og(1D(G(zi)))]

其中xi,zi 分別是真實的圖片資料以及noise變數。
而優化目標則是:

minGmaxDLoss

不過需要注意的一點是,實際訓練過程中並不是直接在上述優化目標上對 θd,θg 計算梯度,而是分成幾個步驟:

訓練判別器即更新θd:迴圈k次,每次準備一組real image資料 x=x1,x2,,xm 和一組fake image資料z=z1,z2,,zm,計算
θd1mmi=1[logD(xi)+log(1D(G(zi)))]
然後梯度上升法更新 θd
訓練生成器即更新 θg :準備一組fake image資料 z=z1,z2,,

zm ,計算
θg1mmi=1log(1D(G(zi)))
然後梯度下降法更新 θg
可以看出,第一步內部有一個k層的迴圈,某種程度上可以認為是因為我們的訓練首先要保證判別器足夠好然後才能開始訓練生成器,否則對應的生成器也沒有什麼作用,然後第二步求提督時只計算fake image那部分資料,這是因為real image不由生成器產生,因此對應的梯度為0。