1. 程式人生 > >生成對抗網路(Generative Adversarial Networks,GAN)初探

生成對抗網路(Generative Adversarial Networks,GAN)初探

1. 從納什均衡(Nash equilibrium)說起

我們先來看看納什均衡的經濟學定義:

所謂納什均衡,指的是參與人的這樣一種策略組合,在該策略組合上,任何參與人單獨改變策略都不會得到好處。換句話說,如果在一個策略組合上,當所有其他人都不改變策略時,沒有人會改變自己的策略,則該策略組合就是一個納什均衡。

B站上有一個關於”海灘2個兄弟賣雪糕“形成納什均衡的視訊,講的很生動。 不管系統中的雙方一開始處於什麼樣的狀態,只要系統中參與競爭的個體都是”理性經濟人“,即每個人在考慮其他人的可能動作的基礎上,出於最大化自己的個人利益作為下一步行動的考慮,那麼最終系統都一定會進入納什均衡狀態,這個狀態也許對於系統來說不是全域性最優的,但是對於系統中的每一個個體來說都是理論最優的。 這樣可能還是有一些抽象,我們用幾個例子來說明。

0x1:價格戰中的納什均衡

市場上有2家企業A和B,都是賣紙的,紙的成本都是2元錢,A和B都賣5塊錢。在最開始,A、B企業都是盈利3塊,這種狀態叫”社會最優解(Social optimal solution)“。但問題是,社會最優解是一個不穩定的狀態,就如同下圖中這個優化曲面上那個紅球點一樣,雖然該小球目前處於曲面最高點,但是隻要施加一些輕微的擾動,小球就會立刻向山下滑落:

現在企業A和B準備開展商業競爭:

  • 有一天,A企業率先降價到4塊錢,於是A銷量大增,B銷量大減。
  • B看到了後,降價到3塊錢,於是B銷量大增,A銷量大減。
  • ......

但如果價格戰一直這樣打下去,這個過程顯然不可能無限迭代下去。當A和B都降價到了3塊時,雙方都達到了成本的臨界點,既不敢漲價,也不敢降價。漲價了市場就丟了,降價了,就賺不到錢甚至賠錢。所以A和B都不會再去做改變,這就是納什均衡。

A和B怎樣能夠獲得最大利潤呢,就是A和B坐到一起商量,同時把價格提高,這就叫共謀,但法律為了保障消費者利益,禁止共謀。補充一句,共謀在機器學習中被稱作”模型坍塌“,指的對對抗的模型雙方都進入了一個互相認可的區域性最優區而不再變化,具體的技術細節我們後面會討論。

0x2:囚徒困境中的納什均衡

囚徒困境是說:有兩個小偷集體作案,然後被警察捉住。警察對兩個人分別審訊,並且告訴他們政策:

  • 如果兩個人都交代坦白,就可以定罪,兩個人各判八年。
  • 如果一個人交代另一個不交代,那麼一樣可以定罪。但是交代的人從寬處罰,批評教育就釋放。不交代的人從嚴處罰,判十年。
  • 如果兩個人都不交代,沒法定罪,每個人判一年意思一下。

兩個人的收益情況如下所示:

因為A和B是不能互相通訊的,因此這是一個靜態不完全資訊博弈,我們分別考慮雙方的決策面:

  • A的決策。A會想,我如何才能獲得更大收益呢?
    • 先考慮最壞的情況:如果B坦白了,那麼我坦白就會判8年,我抗拒就會判十年,我應該坦白;
    • 再考慮最好的情況:如果B抗拒了,我坦白會判0年,我抗拒會判1年,我還是應該坦白;
    • 所以最終A會選擇坦白。
  • 同樣,B也會這樣想。

因此最終納什均衡點在兩個人都坦白,各判八年這裡。

顯然,集體最優解在兩個人都抗拒,這樣一來每個人都判一年就出來了。但是,納什均衡點卻不在這裡。而且,在納什均衡點上,任何一個人都沒有改變自己決策的動力。因為一旦單方面改變決策,那個人的收益就會下降。

0x3:開車加塞現象的納什均衡

我們知道,在國內開車夾塞很常見。如果大家都不夾塞,是整體的最優解,但是按照納什均衡理論,任何一個司機都會考慮,無論別人是否夾塞,我夾塞都可以使自己的收益變大。於是最終大家都會夾塞,加劇擁堵,反而不如大家都不加塞走的快。

那麼,有沒有辦法使個人最優變成集體最優呢?方法就是共謀。兩個小偷在作案之前可以說好,咱們如果進去了,一定都抗拒。如果你這一次敢反悔,那麼以後道上的人再也不會有人跟你一起了。也就是說,在多次博弈過程中,共謀是可能的。但是如果這個小偷想幹完這一票就走,共謀就是不牢靠的。

在社會領域,共謀是靠法律完成的。大家約定的共謀結論就是法律,如果有人不按照約定做,就會受到法律的懲罰。通過這種方式保證最終決策從個人最優的納什均衡點變為集體最優點。

另外一方面,現在很多汽車廠商提出了車聯網的概念,在路上的每一輛車都通過物聯網連成一個臨時網路,所有車按照一個最優的協同演算法共同協定最優的行車路線、行車速度、路口等待等行為,這樣整體交通可以達到一個整體最優,所有人都節省了時間。

0x3:槍手博弈

彼此痛恨的甲、乙、丙三個槍手準備決鬥,他們各自的水平如下:

  • 甲槍法最好,十發八中;
  • 乙槍法次之,十發六中;
  • 丙槍法最差,十發四中;

1. 場景一:三人同時開槍,並且每人只發一槍。每一輪槍戰後,誰活下來的機會大一些?

首先明確一點,這是一個靜態不完全資訊博弈,每個搶手在開槍前都不知道其他對手的策略,只能在猜測其他對手策略的基礎上,選擇對自己最優的策略。

我們來分析一下第一輪槍戰各個槍手的策略。

  • 槍手甲一定要對槍手乙先開槍。因為乙對甲的威脅要比丙對甲的威脅更大,甲應該首先幹掉乙,這是甲的最佳策略。
  • 同樣的道理,槍手乙的最佳策略是第一槍瞄準甲。乙一旦將甲幹掉,乙和丙進行對決,乙勝算的概率自然大很多。
  • 槍手丙的最佳策略也是先對甲開槍。乙的槍法畢竟比甲差一些,丙先把甲幹掉再與乙進行對決,丙的存活概率還是要高一些。

第一輪槍戰過後,有幾種可能的結果:

  • 甲乙雙亡,丙獲勝
  • 甲亡,乙丙存活
  • 乙亡,甲丙存活

現在進入第二輪槍戰:

除非第一輪甲乙雙亡,否則丙就一定處於劣勢,因為不論甲或乙,他們的命中率都比丙的命中率為高。

這就是槍手丙的悲哀。能力不行的丙玩些花樣雖然能在第一輪槍戰中暫時獲勝。但是,如果甲乙在第一輪槍戰中沒有雙亡的話,在第二輪槍戰結束後,丙的存活的機率就一定比甲或乙為低。

這似乎說明,能力差的人在競爭中耍弄手腕能贏一時,但最終往往不能成事。

2. 場景二:三人輪流開槍,沒人只發一槍。丙最後發槍。

我們現在改變遊戲規則,假定甲乙丙不是同時開槍,而是他們輪流開一槍。先假定開槍的順序是甲、乙、丙,我們來分析一下槍戰過程:

  • 甲一槍將乙幹掉後(80%的機率),就輪到丙開槍,丙有40%的機率一槍將甲幹掉。
  • 乙躲過甲的第一槍(20%機率),輪到乙開槍,乙還是會瞄準槍法最好的甲開槍,即使乙這一槍幹掉了甲(60%機率),下一輪仍然是輪到丙開槍(40%機率)。無論是甲或者乙先開槍,乙都有在下一輪先開槍的優勢。

如果是丙先開槍,情況又如何呢?

3. 場景三:三人輪流開槍,沒人只發一槍。丙第一個發槍。

  • 丙可以向甲先開槍(40%機率),
    • 即使丙打不中甲,甲的最佳策略仍然是向乙開槍。
    • 但是,如果丙打中了甲,下一輪可就是乙開槍打丙了。
  • 因此,丙的最佳策略是胡亂開一槍,只要丙不打中甲或者乙,在下一輪射擊中他就處於有利的形勢(先發優勢)。

我們通過這個例子,可以理解人們在博弈中能否獲勝,不單純取決於他們的實力,更重要的是取決於博弈方實力對比所形成的關係。

在上面的例子中,乙和丙實際上是一種聯盟關係,先把甲幹掉,他們的生存機率都上升了。我們現在來判斷一下,乙和丙之中,誰更有可能背叛,誰更可能忠誠?

任何一個聯盟的成員都會時刻權衡利弊,一旦背叛的好處大於忠誠的好處,聯盟就會破裂。在乙和丙的聯盟中,乙是最忠誠的。這不是因為乙本身具有更加忠誠的品質,而是利益關係使然。只要甲不死,乙的槍口就一定會瞄準甲。但丙就不是這樣了,丙不瞄準甲而胡亂開一槍顯然違背了聯盟關係,丙這樣做的結果,將使乙處於更危險的境地。

合作才能對抗強敵。只有乙丙合作,才能把甲先幹掉。如果,乙丙不和,乙或丙單獨對甲都不佔優,必然被甲先後解決。、

1966年經典電影《黃金三鏢客》中的最後一幕,三個主人公手持槍桿站在墓地中,為了寶藏隨時準備決一死戰。為了活著拿到寶藏,倖存下來的最優策略是什麼呢? 

0x4:蒙古聯合南宋滅金

當時,蒙古軍事實力最強,金國次之,南宋武力最弱。本來南宋應該和金國結盟,幫助金國抵禦蒙古的入侵才是上策,或者至少保持中立。但是,當時的南宋採取了和蒙古結盟的政策。南宋當局先是糊塗地同意了拖雷借道宋地伐金。1231年,蒙古軍隊在宋朝的先遣隊伍引導下,借道四川等地,北度漢水殲滅了金軍有生力量。

1233年,南宋軍隊與蒙古軍隊合圍蔡州,金朝最後一個皇帝在城破後死於亂兵,金至此滅亡。1279年,南宋正式亡於蒙古。

如果南宋當政者有戰略眼光,捐棄前嫌,與世仇金結盟對抗最強大的敵人蒙古,宋和金都不至於那麼快就先後滅亡了。

0x5:智豬博弈

豬圈裡面有兩隻豬, 一隻大,一隻小。豬圈很長,一頭有一個踏板,另一頭是飼料的出口和食槽。每踩一下踏板,在遠離踏板的豬圈的另一邊的投食口就會落下少量的食物。如果有一隻豬去踩踏板,另一隻豬就有機會搶先吃到另一邊落下的食物。

  • 當小豬踩動踏板時,大豬會在小豬跑到食槽之前剛好吃光所有的食物;
  • 若是大豬踩動了踏板,則還有機會在小豬吃完落下的食物之前跑到食槽,爭吃到另一半殘羹。

那麼,兩隻豬各會採取什麼策略?令人出乎意料的是,答案居然是:小豬將選擇“搭便車”策略,也就是舒舒服服地等在食槽邊;而大豬則為一點殘羹不知疲倦地奔忙於踏板和食槽之間。

原因何在呢?我們來分析一下,首先這是一個靜態不完全資訊博弈:

  • 小豬踩踏板:小豬將一無所獲,不踩踏板反而能吃上食物。對小豬而言,無論大豬是否踩動踏板,不踩踏板總是好的選擇。
  • 反觀大豬,已明知小豬是不會去踩動踏板的,自己親自去踩踏板總比不踩強吧,所以只好親力親為了。

“智豬博弈”的結論似乎是,在一個雙方公平、公正、合理和共享競爭環境中,有時佔優勢的一方最終得到的結果卻有悖於他的初始理性。這種情況在現實中比比皆是。

比如,在某種新產品剛上市,其效能和功用還不為人所熟識的情況下,如果進行新產品生產的不僅是一家小企業,還有其他生產能力和銷售能力更強的企業。那麼,小企業完全沒有必要作出頭鳥,自己去投入大量廣告做產品宣傳,只要採用跟隨戰略即可。

“智豬博弈”告訴我們,誰先去踩這個踏板,就會造福全體,但多勞卻並不一定多得。

在現實生活中,很多人都只想付出最小的代價,得到最大的回報,爭著做那隻坐享其成的小豬。“一個和尚挑水喝,兩個和尚擡水喝,三個和尚沒水喝”說的正是這樣一個道理。這三個和尚都想做“小豬”,卻不想付出勞動,不願承擔起“大豬”的義務,最後導致每個人都無法獲得利益。

0x6:證券市場中的“智豬博弈”

金融證券市場是一個群體博弈的場所,其真實情況非常複雜。在證券交易中,其結果不僅依賴於單個參與者自身的策略和市場條件,也依賴其他人的選擇及策略。

在“智豬博弈”的情景中,大豬是佔據比較優勢的,但是,由於小豬別無選擇,使得大豬為了自己能吃到食物,不得不辛勤忙碌,反而讓小豬搭了便車,而且比大豬還得意。這個博弈中的關鍵要素是豬圈的設計, 即踩踏板的成本。

證券投資中也是有這種情形的。例如,當莊家在底位買入大量股票後,已經付出了相當多的資金和時間成本,如果不等價格上升就撤退,就只有接受虧損。

所以,基於和大豬一樣的貪吃本能,只要大勢不是太糟糕,莊家一般都會擡高股價,以求實現手中股票的增值。這時的中小散戶,就可以對該股追加資金,當一隻聰明的“小豬”,而讓 “大豬”莊家力擡股價。當然,這種股票的發覺並不容易,所以當“小豬”所需要的條件,就是發現有這種情況存在的豬圈,並衝進去。這樣,你就成為一隻聰明的“小豬”。

股市中,散戶投資者與小豬的命運有相似之處,沒有能力承擔炒作成本,所以就應該充分利用資金靈活、成本低和不怕被套的優勢,發現並選擇那些機構投資者已經或可能坐莊的股票,等著大豬們為自己服務。

由此看到,散戶和機構的博弈中,散戶並不是總沒有優勢的,關鍵是找到有大豬的那個食槽,並等到對自己有利的遊戲規則形成時再進入。

0x7:納什均衡博弈與GAN網路的關係

GAN的主要靈感來源於博弈論中零和博弈的思想。

應用到深度學習神經網路上來說,就是通過生成網路G(Generator)和判別網路D(Discriminator)不斷博弈,進而使 G 學習到資料的分佈,同時時 D 獲得更好的魯棒性和泛化能力。

舉個例子:用在圖片生成上,我們想讓最後的 G 可以從一段隨機數中生成逼真的影象:

上圖中:

  • G是一個生成式的網路,它接收一個隨機的噪聲 z(隨機數),然後通過這個噪聲生成影象。

  • D是一個判別網路,判別一張圖片是不是 “真實的”。它的輸入是一張圖片,輸出的 D(x) 代表 x 為真實圖片的概率,如果為 1,就代表 100% 是真實的圖片,而輸出為 0,就代表不可能是真實的圖片。

那麼這個訓練的過程是什麼樣子的呢?在訓練中:

  • G 的目標就是儘量生成真實的圖片去欺騙判別網路 D。

  • D的目標就是儘量辨別出G生成的假影象和真實的影象。

這樣,G 和 D 就構成了一個動態的“博弈過程”,最終的平衡點即納什均衡點。

Relevant Link:     

https://baijiahao.baidu.com/s?id=1611846467821315306&wfr=spider&for=pc
https://www.jianshu.com/p/fadba906f5d3 

 

2. GAN網路的思想起源

GAN的起源之作鼻祖是 Ian Goodfellow 在 2014 年發表在 ICLR 的論文:Generative Adversarial Networks”。

按照筆者的理解,提出GAN網路的出發點有如下幾個:

  • 最核心的作用是提高分類器的魯棒能力,因為生成器不斷生成”儘量逼近真實樣本“的偽造影象,而分類器為了能正確區分出偽造和真實的樣本,就需要不斷地挖掘樣本中真正蘊含的潛在概率資訊,而拋棄無用的多餘特徵,這就起到了提高魯棒和泛化能力的作用。從某種程度上來說,GAN起到了和正則化約束的效果。
  • 基於隨機擾動,有針對性地生成新樣本。但是要注意的一點是,GAN生成的樣本並不是完全的未知新樣本,GAN的generator生成的新樣本更多的側重點是通過增加可控的擾動來嘗試躲避discriminator的檢測。實際上,GAN對生成0day樣本的能力很有限。

為了清楚地闡述這個概念,筆者先從對抗樣本這個話題開始說起。

0x1:對抗樣本(adversarial example)

對抗樣本(adversarial example)是指經過精心計算得到的用於誤導分類器的樣本。例如下圖就是一個例子,左邊是一個熊貓,但是添加了少量隨機噪聲變成右圖後,分類器給出的預測類別卻是長臂猿,但視覺上左右兩幅圖片並沒有太大改變。

出現這種情況的原因是什麼呢?

簡單來說,就是預測器發生了過擬合。影象分類器本質上是高維空間的一個複雜的決策函式,在高維空間上,影象分類器過分考慮了全畫素區間內的細節資訊,導致預測器對影象的細節資訊太敏感,微小的擾動就可能導致預測器的預測行為產生很大的變化。

關於這個話題,筆者在另一篇文章中對過擬合現象以及規避方法進行了詳細討論。

除了新增”隨機噪聲驅動的畫素擾動”這種方法之外,還可以通過影象變形的方式,使得新影象和原始影象視覺上一樣的情況下,讓分類器得到有很高置信度的錯誤分類結果。這種過程也被稱為對抗攻擊(adversarial attack)。

0x2:有監督驅動的無監督學習

人類通過觀察和體驗物理世界來學習,我們的大腦十分擅長預測,不需要顯式地經過複雜計算就可以得到正確的答案。監督學習的過程就是學習資料和標籤之間的相關關係。

但是在非監督學習中,資料並沒有被標記,而且目標通常也不是對新資料進行預測。

在現實世界中,標記資料是十分稀有和昂貴的。生成對抗網路通過生成偽造的/合成的資料並嘗試判斷生成樣本真偽的方法學習,這本質上相當於採用了監督學習的方法來做無監督學習。做分類任務的判別器在這裡是一個監督學習的元件,生成器的目標是瞭解真實資料的模樣(概率分佈),並根據學到的知識生成新的資料。

Relevant Link:  

https://www.jiqizhixin.com/articles/2018-03-05-4

 

3. GAN網路基本原理

GAN網路發展到如今已經有很多的變種,在arxiv上每天都會有大量的新的研究論文被提出。但是筆者這裡不準備列舉所有的網路結構,而是僅僅討論GAN中最核心的思想,通過筆者自己的論文閱讀,將我認為最精彩的思想和學術創新提煉出來給大家,今後我們也可以根據自己的理解,將其他領域的思想交叉引入進來,繼續不斷創新發展。

0x1:GAN的組成

 

經典的GAN網路由兩部分組成,分別稱之為判別器D和生成器G,兩個網路的工作原理可以如下圖所示,

D 的目標就是判別真實圖片和 G 生成的圖片的真假,而 G 是輸入一個隨機噪聲來生成圖片,並努力欺騙 D。

簡單來說,GAN 的基本思想就是一個最小最大定理,當兩個玩家(D 和 G)彼此競爭時(零和博弈),雙方都假設對方採取最優的步驟而自己也以最優的策略應對(最小最大策略),那麼結果就會進入一個確定的均衡狀態(納什均衡)。

0x2:損失函式分析

1. 生成器(generator)損失函式

生成器網路以隨機的噪聲z作為輸入並試圖生成樣本資料,並將生成的偽造樣本資料提供給判別器網路D,

可以看到,G 網路的訓練目標就是讓 D(G(z)) 趨近於 1,即完全騙過判別器(判別器將生成器生成的偽造樣本全部誤判為真)。G 網路通過接受 D 網路的反饋作為梯度改進方向,通過BP過程反向調整自己的網路結構引數。

2. 判別器(discriminator)

判別器網路以真實資料x或者偽造資料G(z)作為輸入,並試圖預測當前輸入是真實資料還是生成的偽造資料,併產生一個【0,1】範圍內的預測標量值。

D 網路的訓練目標是區分真假資料,D 網路的訓練目標是讓 D(x) 趨近於 1(真實的樣本判真),而 D(G(z)) 趨近於0(偽造的樣本判黑)。D 網路同時接受真實樣本和 G 網路傳入的偽造樣本作為梯度改進方向,,通過BP過程反向調整自己的網路結構引數。

3. 綜合損失函式

生成器和判別器網路的損失函式結合起來就是生成對抗網路(GAN)的綜合損失函式:

兩個網路相互對抗,彼此博弈,如上所示,綜合損失函式是一個極大極小函式;

  • 損失函式第一項:會驅使判別器儘量將真實樣本都判真
  • 損失函式第二項:會驅使判別器儘量將偽造樣本都判黑。但同時,生成器G會對抗這個過程

整個相互對抗的過程,Ian Goodfellow 在論文中用下圖來描述:

 

 黑色曲線表示輸入資料 x 的實際分佈,綠色曲線表示的是 G 網路生成資料的分佈,紫色的曲線表示的是生成資料對應於 D 的分佈的差異距離(KL散度)

GAN網路訓練的目標是希望著實際分佈曲線x,和G網路生成的資料,兩條曲線可以相互重合,也就是兩個資料分佈一致(達到納什均衡)。

  • a圖:網路剛開始訓練,D 的分類能力還不是最好,因此有所波動,而生成資料的分佈也自然和真實資料分佈不同,畢竟 G 網路輸入是隨機生成的噪聲;
  • b圖:隨著訓練的進行,D 網路的分類能力就比較好了,可以看到對於真實資料和生成資料,它是明顯可以區分出來,也就是給出的概率是不同的;
  • c圖:由於 D 網路先行提高的效能,隨後 G 網路開始追趕,G 網路的目標是學習真實資料的分佈,即綠色的曲線,所以它會往藍色曲線方向移動。因為 G 和 D 是相互對抗的,當 G 網路提升,也會影響 D 網路的分辨能力;
  • d圖:當假設 G 網路不變(G已經優化到收斂狀態),繼續訓練 D 網路,最優的情況會是,也就是當生成資料的分佈趨近於真實資料分佈的時候,D 網路輸出的概率會趨近於 0.5(真實樣本和偽造樣本各佔一半,生成器無法再偽造了,判別器也無法再優化了,也可以說對於判別器來說其無法從樣本中區分中真實樣本和偽造樣本),這也是最終希望達到的訓練結果,這時候 G 和 D 網路也就達到一個平衡狀態。

0x3:演算法偽碼流程

論文給出的演算法實現過程如下所示:

一些細節需要注意:

  • 首先 G 和 D 是同步訓練,但兩者訓練次數不一樣,通常是 D 網路訓練 k 次後,G 訓練一次。主要原因是 GAN 剛開始訓練時候會很不穩定,需要讓判別器D儘快先進入收斂區間;
  • D 的訓練是同時輸入真實資料和生成資料來計算 loss,而不是採用交叉熵(cross entropy)分開計算。不採用 cross entropy 的原因是這會讓 D(G(z)) 變為 0,導致沒有梯度提供給 G 更新,而現在 GAN 的做法是會收斂到 0.5;
  • 實際訓練的時候,作者是採用來代替,這是希望在訓練初始就可以加大梯度資訊,這是因為初始階段 D 的分類能力會遠大於 G 生成足夠真實資料的能力,但這種修改也將讓整個 GAN 不再是一個完美的零和博弈。

0x4:演算法的優點

GAN的巧妙之處在於其目標函式的設定,因為此,GAN有如下幾個優點:

  • GAN 中的 G 作為生成模型,不需要像傳統圖模型一樣,需要一個嚴格的生成資料的概率表示式。這就避免了當資料非常複雜的時候,複雜度過度增長導致的不可計算。
  • GAN 不需要 inference 模型中的一些龐大計算量的求和計算。它唯一的需要的就是,一個噪音輸入,一堆無標準的真實資料,兩個可以逼近函式的網路。

0x5:演算法的挑戰與缺陷

初代GAN有一些缺點,或者是說挑戰,筆者這裡介紹如下:
  • 啟動及初始化的問題:GAN的訓練目標是讓生成器和判別器最終達到一個納什均衡狀態,此時兩個網路都無法繼續再往前做任何優化,優化結束。梯度下降的啟動會選擇一個減小所定義問題損失的方法,但是並沒有理論保證GAN一定可以100%進入納什均衡狀態,這是一個高維度的非凸優化目標。網路試圖在接下來的步驟中最小化非凸優化目標,但是最終可能導致進入震盪而不是收斂到底層真實目標。
  • GAN 過於自由導致訓練難以收斂以及不穩定。
  • 梯度消失問題:原始 G 的損失函式沒有意義,它是讓 G 最小化 D 識別出自己生成的假樣本的概率,但實際上它會導致梯度消失問題,這是由於開始訓練的時候,G 生成的圖片非常糟糕,D 可以輕而易舉的識別出來,這樣 D 的訓練沒有任何損失,也就沒有有效的梯度資訊回傳給 G 去優化它自己,這就是梯度消失了。最後,雖然作者意識到這個問題,在實際應用中改用來代替,這相當於從最小化 D 揪出自己的概率,變成了最大化 D 抓不到自己的概率。雖然直觀上感覺是一致的,但其實並不在理論上等價,也更沒有了理論保證在這樣的替代目標函式訓練下,GAN 還會達到平衡。這個結果會進一步導致模式奔潰問題。
  • 模型坍塌:基本原理是生成器可能會在某種情況下重複生成完全一致的影象(也可以理解為梯度消失),這其中的原因和博弈論中的啟動問題相關。我們可以這樣來想象GAN的訓練過程,
    • 先從判別器的角度試圖最大化,再從生成器的角度試圖最小化。如果生成器最小化開始之前,判別器已經完全最大化,所有工作還可以正常執行;
    • 如果首先最小化生成器,再從判別器的角度試圖最大化。如果判別器最大化開始之前,生成器已經完全最小化,那麼工作就無法執行。原因在於如果我們保持判別器不變,它會將空間中的某些點標記為最有可能是真的而不是假的(因為生成器已經最小化了),這樣生成器就會選擇將所有的噪聲輸入對映到那些最可能為真的點上,這就陷入了局部最優的陷阱中了,優化過程就提前停止了。
當然上面提到的很多缺點已經在後續的學術論文中被新提出的修改演算法解決了,我們接下來討論其主要解決思想。

0x6:提升GAN訓練效果的一些方法

1. 中間層特徵驅動損失函式

針對GAN不穩定的問題,學者們提出了通過使用判別器中間層的特徵來預測影象,並將結果作為監督資訊來反饋給生成器。 通過這種方式,訓練得到的生成器的生成資料會匹配真實資料的統計特性以及判別器中間層的預期特徵值。這樣強迫判別器去尋找那些最能很好地判別真實資料的潛在特徵,而不是那些由當前模型生成資料的表層特徵。

2. 小批量度量輸入樣本相似度

模型坍塌的問題可以通過引入額外的度量特徵來解決(例如KL散度)。這樣判別器每次收到的是一小批樣本而不是一個單獨樣本,判別器可以使用例如KL散度來度量樣本之間的距離,這樣就很容易檢測出當前的生成器是不是已經開始坍塌。從而阻止了生成器繼續向區域性最大似然點滑落。 總體來說,小批量樣本表現更接近實際,而且可以保證不同樣本之間在空間上有合適的距離。

3. 引入歷史平均

歷史平均的思想是加入一個懲罰項來懲罰那些和歷史平均權重相差過多的權重值。即如果當前引數值和歷史上最近t批該引數平均值的距離越近,給予的懲罰越大。 通過這種方式,可以緩解目標函式在收斂後期的震盪。

4. 單側標籤平滑

通常情況下我們使用標籤0代表真實影象,使用1代表偽造影象。我們還可以使用一些更平滑的標籤,例如0.1和0.9,它們可以使得網路在一些對抗的例子中更加健壯。

5. 輸入規範化

使用tanh作為生成器最後一層啟用函式,可以獲得更平滑的收斂效果。

6. 批規範化

在每一個批次的資料中標準化前一層的啟用項, 即,應用一個維持啟用項平均值接近 0,標準差接近 1 的轉換。

7. 利用ReLU和MaxPool避免梯度稀疏

如果梯度稀疏,GAN博弈的穩定性會受到很大影響,Leaky ReLU對生成和判別器的梯度稀疏問題都會有緩解作用。

Relevant Link:  

https://arxiv.org/pdf/1406.2661.pdf 
https://juejin.im/post/5bdd70886fb9a049f912028d 
http://www.iterate.site/2018/07/27/gan-%E7%94%9F%E6%88%90%E5%AF%B9%E6%8A%97%E7%BD%91%E7%BB%9C%E4%BB%8B%E7%BB%8D/

 

4. 從生成模型和判別模型的概率視角看GAN

在閱讀了很多GAN衍生論文以及GAN原始論文之後,筆者一直在思考的一個問題是:GAN背後的底層思想是什麼?GAN衍生和改進演算法的靈感和思路又是從哪裡來的?

經過一段時間思考以及和同行同學討論後,我得出了一些思考,這裡分享如下,希望對讀者朋友有幫助。

我們先來看什麼是判別模型和生成模型:

  • 判別式模型學習某種分佈下的條件概率p(y|x),即在特定x條件下y發生的概率。判別器模型十分依賴資料的質量,概率分佈p(y|x)可以直接將一個特定的x分類到某個標籤y上。以邏輯迴歸為例,我們所需要做的是最小化損失函式。
  • 生成式模型學習的是聯合分佈概率p(x,y),x是輸入資料,y是所期望的分類。一個生成模型可以根據當前資料的假設生成更多新樣本。

從概率論的視角來看,我們來看一下原始GAN網路的架構:

  • 生成器本質上是一個由輸入向量和生成器結構所代表的向量組成的聯合概率分佈P(v_input, v_G_structure)
    • v_input:代表一種輸入向量,可以是隨機噪聲向量z
    • v_G_structure:網路本質上是對輸入向量進行線性和非線性變化,因為可以將其抽象為一個動態變化的向量函式
  • 判別器本質上是一個由(真實樣本,偽造樣本)作為輸入x,進行後驗預測p(y|x)的概率模型

遵循這種框架進行思考,CGAN只是將v_input中的隨機噪聲z替換成了另一種向量(文字或者標籤向量),而Pix2pixGAN是將一個影象向量作為v_input輸入GAN網路。

 

5. 從原始GAN網路中衍生出的流行GAN架構

GAN的發展離不開goodfellow後來的學者們不斷的研究與發展,目前已經提出了很多優秀的新GAN架構,並且這個發展還在繼續。為了讓本博文能保持一定的環境獨立性,筆者這裡不做完整的羅列與列舉,相反,筆者希望從兩條脈絡來展開討論:

  • 解決問題導向:為了解決原始GAN或者當前學術研究中發現的關於GAN網路的效能和架構問題而提出的新理論與新框架
  • 新場景應用導向:為了將GAN應用在新的領域中而提出的新的GAN架構

0x1:DCGAN(Deep Convolutional Generative Adversarial Networks)

Alec Radford,Luke Metz,Soumith Chintala等人在“Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”提出了DCGAN。這是GAN研究的一個重要里程碑,因為它提出了一個重要的架構變化來解決訓練不穩定,模式崩潰和內部協變數轉換等問題。從那時起,基於DCGAN的架構就被應用到了許多GAN架構。

DCGAN的提出主要是為了解決原始GAN架構的原生架構問題,我們接下來來討論下。

1. 生成器的架構優化

生成器從潛在空間中得到100維噪聲向量z,通過一系列卷積和上取樣操作,將z對映到一個畫素矩陣對應的空間中,如下圖:

DCGAN通過下面的一些架構性約束來固化網路: 

  • 在判別器中使用步數卷積來取代池化層,在生成器中使用小步數卷積來取代池化層;
  • 在生成器和判別器中均使用批規範化,批規範化是一種通過零均值和單位方差的方法進行輸入規範化使得學習過程固話的技術。這項技術在實踐中被證實可以在許多場合提升訓練速度,減少初始化不佳帶來的啟動問題,並且通常能產生更準確的結果;
  • 消除原架構中較深的全連線隱藏層,並且在最後只使用簡單的平均值池化;
  • 在生成器輸出層使用tanh,在其它層均使用ReLU激發;
  • 在判別器的所有層中都使用Leaky ReLU激發;

2. 模型訓練

生成器和判別器都是通過binary_crossentropy作為損失函式來進行訓練的。之後的每個階段,生成器產生一個MNIST影象,判別器嘗試在真實MNIST影象和生成影象的資料集中進行學習。

經過一段時間後,生成器就可以自動學會如何製作偽造的數字。

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

class DCGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(
            loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy']
        )

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0,1),(0,1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and generate a batch of new images
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator (real classified as ones and generated as zeros)
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Train the generator (wants discriminator to mistake images as real)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    dcgan = DCGAN()
    dcgan.train(epochs=4000, batch_size=32, save_interval=50)

DCGAN產生的手寫數字輸出

0x2:CGAN(Conditional GAN,CGAN)

1. 有輸入條件約束的生成器網路架構 

CGAN由Mehdi Mirza,Simon Osindero在論文“Conditional Generative Adversarial Nets”中首次提出。

在條件GAN中,生成器並不是從一個隨機的噪聲分佈中開始學習,而是通過一個特定的條件或某些特徵(例如一個影象標籤或者一些文字資訊)開始學習如何生成偽造樣本。

 

在CGAN中,生成器和判別器的輸入都會增加一些條件變數y,這樣判別器D(x,y)和生成器G(z,y)都有了一組聯合條件變數。

我們將CGAN的目標函式和GAN進行對比會發現:

 

 GAN目標函式

 

CGAN目標函式

GAN和CGAN的損失函式區別在於判別器和生成器多出來一個引數y,架構上,CGAN相比於GAN增加了一個輸入層條件向量C,同時連線了判別器和生成器網路。

 

2. 訓練過程

在訓練過程,我們將y輸入給生成器和判別器網路。 

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

class CGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(
            loss=['binary_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy']
        )

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated image as input and determines validity
        # and the label of that image
        valid = self.discriminator([img, label])

        # The combined model  (stacked generator and discriminator)
        # Trains generator to fool discriminator
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
            optimizer=optimizer)

    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)

        model_input = multiply([flat_img, label_embedding])

        validity = model(model_input)

        return Model([img, label], validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Configure input
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, 100))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict([noise, labels])

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Condition on labels
            sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)

            # Train the generator
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 2, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        sampled_labels = np.arange(0, 10).reshape(-1, 1)

        gen_imgs = self.generator.predict([noise, sampled_labels])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
                axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    cgan = CGAN()
    cgan.train(epochs=20000, batch_size=32, sample_interval=200)

根據輸入數字生成對應的MNIST手寫數字影象

0x3:CycleGAN(Cycle Consistent GAN,迴圈一致生成網路)

CycleGANs 由Jun-Yan Zhu,Taesung Park,Phillip Isola和Alexei A. Efros在題為“Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks”的論文中提出

CycleGAN用來實現不需要其他額外資訊,就能將一張影象從源領域對映到目標領域的方法,例如將照片轉換為繪畫,將夏季拍攝的照片轉換為冬季拍攝的照片,或將馬的照片轉換為斑馬照片,或者相反。總結來說,CycleGAN常備用於不同的影象到影象翻譯。

 

1. 迴圈網路架構

CycleGAN背後的核心思想是兩個轉換器F和G,其中:

  • F會將影象從域A轉換到域B;
  • G會將影象從域B轉換到域A;

因此,

  • 對於一個在域A的影象x,我們期望函式G(F(x))的結果與x相同,即 x == G(F(x));
  • 對於一個在域B的影象y,我們期望函式F(G(y))的結果與y相同,即 y == F(G(y));

和原始的GAN結構相比,由單個G->D的單向開放結構,變成了由兩對G<->D組成的雙向迴圈的封閉結構,但形式上依然是G給D輸入偽造樣本。但區別在於梯度的反饋是雙向迴圈的。

2. 損失函式

CycleGAN模型有以下兩個損失函式:

  • 對抗損失(Adversarial Loss):判別器和生成器之間互相對抗的損失,這就是原始GAN網路的損失函式公式:
  • 迴圈一致損失(Cycle Consistency Loss):綜合權衡轉換器F和G的損失,F和G之間是編碼與解碼的對抗關係,不可能同時取到最小值,只能得到整體的平衡最優值:

完整的CycleGAN目標函式如下:

from __future__ import print_function, division
import scipy

from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os

class CycleGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'horse2zebra'
        self.data_loader = DataLoader(
            dataset_name=self.dataset_name,
            img_res=(self.img_rows, self.img_cols)
        )

        # Calculate output shape of D (PatchGAN)
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 64

        # Loss weights
        self.lambda_cycle = 10.0                    # Cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle    # Identity loss

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(
            loss='mse',
            optimizer=optimizer,
            metrics=['accuracy']
        )
        self.d_B.compile(
            loss='mse',
            optimizer=optimizer,
            metrics=['accuracy']
        )

        # -------------------------
        # Construct Computational
        #   Graph of Generators
        # -------------------------

        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determines validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = Model(
            inputs=[img_A, img_B],
            outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id ]
        )
        self.combined.compile(
            loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
            loss_weights=[1, 1, self.lambda_cycle, self.lambda_cycle, self.lambda_id, self.lambda_id],
            optimizer=optimizer
        )

    def build_generator(self):
        """U-Net Generator"""

        def conv2d(layer_input, filters, f_size=4):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            d = InstanceNormalization()(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)

        # Upsampling
        u1 = deconv2d(d4, d3, self.gf*4)
        u2 = deconv2d(u1, d2, self.gf*2)
        u3 = deconv2d(u2, d1, self.gf)

        u4 = UpSampling2D(size=2)(u3)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

        return Model(d0, output_img)

    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape)

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model(img, validity)

    def train(self, epochs, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total disciminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)


                # ------------------
                #  Train Generators
                # ------------------

                # Train the generators
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                        [valid, valid,
                                                        imgs_A, imgs_B,
                                                        imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                        % ( epoch, epochs,
                                                                            batch_i, self.data_loader.n_batches,
                                                                            d_loss[0], 100*d_loss[1],
                                                                            g_loss[0],
                                                                            np.mean(g_loss[1:3]),
                                                                            np.mean(g_loss[3:5]),
                                                                            np.mean(g_loss[5:6]),
                                                                            elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

    def sample_images(self, epoch, batch_i):
        if not os.path.exists('images/%s' % self.dataset_name):
            os.makedirs('images/%s' % self.dataset_name)
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        # Demo (for GIF)
        #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
        #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close()


if __name__ == '__main__':
    gan = CycleGAN()
    gan.train(epochs=200, batch_size=1, sample_interval=200)

蘋果->橙子->蘋果 

有類似架構思想的還有DiscoGAN,相關論文可以在axiv上找到。

0x4:StackGAN

StackJANs由Han Zhang,Tao Xu,Hongsheng Li還有其他人在題為“StackGAN: Text to Photo-Realistic Image Synthesis with Stacked Generative Adversarial Networks”的論文中提出。他們使用StackGAN來探索文字到影象的合成,得到了非常好的結果。

一個StackGAN由一對網路組成,當提供文字描述時,可以生成逼真的影象。

0x5:Pix2pix

pix2pix網路由Phillip Isola,Jun-Yan Zhu,Tinghui Zhou和Alexei A. Efros在他們的題為“Image-to-Image Translation with Conditional Adversarial Networks”的論文中提出。

對於影象到影象的翻譯任務,pix2pix也顯示出了令人印象深刻的結果。無論是將夜間影象轉換為白天的影象還是給黑白影象著色,或者將草圖轉換為逼真的照片等等,Pix2pix在這些例子中都表現非常出色。

0x6:Age-cGAN(Age Conditional Generative Adversarial Netw