1. 程式人生 > >[TensorFlow]生成對抗網路(GAN)介紹與實踐

[TensorFlow]生成對抗網路(GAN)介紹與實踐

主旨

本文簡要介紹了生成對抗網路(GAN)的原理,接下來通過tensorflow開發程式實現生成對抗網路(GAN),並且通過實現的GAN完成對等差數列的生成和識別。通過對設計思路和實現方案的介紹,本文可以輔助讀者理解GAN的工作原理,並掌握實現方法。有了這樣的基礎,在面對工作中實際問題時可以將GAN納入考慮,選擇最合適的演算法

程式碼和執行環境

TensorFlow版本

>>> tf.version 
‘1.1.0-rc2’

背景知識

Generative Adversarial Nets[1][https://arxiv.org/pdf/1406.2661v1.pdf]是Ian J. Goodfellow等在2014年提出的一種訓練模型的方法,此方法通過兩個網路(生成網路G和分類網路D)對抗訓練,得到符合預期目標的生成模型和分類模型。

要理解GAN的原理,上述論文是最好的教材。但考慮到原文首先是英文撰寫,其次包含不少數學推導,新手上手並不容易。因此筆者這裡班門弄斧,基於論文簡單轉述GAN的設計思想要點

GAN的目標,給定一個真實樣本(本文也稱之為ground truth)集合,訓練出兩個模型,一個能夠從噪聲訊號生成儘可能像ground truth的樣本;另一個能夠判斷給定樣本是否是ground truth。兩個模型詳細介紹如下

  • 生成模型:論文中稱為generative model,本文稱為G網路或G模型。G網路的輸入是噪聲訊號(例如均勻分佈的隨機數),輸出為形狀與真實樣本ground truth一致。G網路的訓練目標是,儘可能輸出與ground truth相似的樣本。這裡“相似”定義為:如果G網路生成的一個樣本騙過了D網路,使得D網路誤以為這就是真實樣本,則就是相似的,G網路獲得獎勵;反之,獲得懲罰。
  • 分類模型:論文中稱為discriminative model,本文稱為D網路或D模型。D網路是一個2分類器,輸入為ground truth或者G網路生成的樣本,輸出為TRUE或FALSE:TRUE表示D網路認為當前輸入樣本是ground truth,FALSE表示D網路認為當前輸入樣本是G網路生成的“偽造”樣本。D網路的訓練目標是儘可能正確的區分開ground truth和G網路生成的“偽造”樣本。

從上述討論可以看出,G網路和D網路是兩個目標完全相反的網路,G網路盡其所能“偽造”出像真實樣本的資料,D網路儘可能區分真實與偽造資料。GAN中所謂“對抗”的概念,即來源於此。

GAN的訓練過程就是G和D兩個網路互相對抗的過程,對抗的結果是G網路被訓練到能夠生成以假亂真的樣本,即G網路從噪聲輸入得到了儘可能與真實樣本相似的輸出,或者說G學會了從噪聲生成ground truth的方法;D網路也可以區分ground truth與其他樣本,即D學會了區分ground truth與其他資料的方法。

參考文獻 
1. Goodfellow I J, Pougetabadie J, Mirza M, et al. Generative adversarial nets[C]. neural information processing systems, 2014: 2672-2680.

神經網路設計和實現

問題構造

在開始設計神經網路之前,我們首先構造出預期GAN解決的問題。前述GAN論文中提出了一個從噪聲學習正態分佈的經典問題,讀者如果在網路上搜索GAN的案例,除了影象識別,基本上只有這麼一個問題和方案實現。

本文重新設計了一個與論文中不同的問題。問題描述如下

  • Ground Truth定義:[1,2,3,4,5,6,7,8,9,10]構成的等差數列,為了適當降低學習難度,此數列每個元素與噪聲相加,噪聲為0均值正態分佈隨機變數,標準差取0.1, 0.03, 0等不同數值
  • 輸入噪聲定義: [-1,1]之間均勻分佈的隨機變數。

網路結構設計

G網路:參考論文資料,我們選擇多層全連線神經網路

D網路:由於要分辨的是等差數列,我們選擇RNN作為D網路。

網路結構如下(下圖是tensorboard生成的計算圖):圖中”G_net”表示G網路,”D_net”/”D_net_1”表示D網路,雖然圖中D網路被分成了兩份,但是其RNN引數是共享的,即圖中正下方”rnn”這個單元。

程式碼實現

G網路定義

   # generative network
    # use multi-layer percepton to generate time sequence from random noise
    # input tensor must be in shape of (batch_size, self.seq_len)
    def generator(self, inputTensor):
        with tf.name_scope('G_net'):
            gInputTensor = tf.identity(inputTensor, name='input')
            # Multilayer percepton implementation
            numNodesInEachLayer = 10
            numLayers = 3 

            previous_output_tensor = gInputTensor
            for layerIdx in range(numLayers):
                activation,z = self.fullConnectedLayer(previous_output_tensor, numNodesInEachLayer, layerIdx)
                previous_output_tensor = activation

            g_logit = z
            g_logit = tf.identity(g_logit, 'g_logit')
            return g_logit

G網路損失函式 
下面程式碼片段中self.d_logit_fake是D網路對G網路生成資料的判定結果。由於G網路的目標是儘可能騙過D網路,如果D網路對於G網路生成資料全部判為1(即TRUE),則損失最小,反之,損失最大。

g_loss_d = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.d_logit_fake,
                    labels=tf.ones(shape=[self.batch_size_t,1])
                    ),
                name='g_loss_d'
                )

D網路的定義 
RNN+全連線輸出層,無論是RNN還是全連線層都必須在對ground truth和G生成樣本之間共享同一套引數

def discriminator(self, inputTensor,reuseParam):
        with tf.name_scope('D_net'):
            num_units_in_LSTMCell = 10

            # RNN definition
            with tf.variable_scope('d_rnn'):
                lstmCell = tf.contrib.rnn.BasicLSTMCell(num_units_in_LSTMCell,reuse=reuseParam)
                init_state = lstmCell.zero_state(self.batch_size_t, dtype=tf.float32)
                raw_output, final_state = tf.nn.dynamic_rnn(lstmCell, inputTensor, initial_state=init_state)

            rnn_output_list = tf.unstack(tf.transpose(raw_output, [1, 0, 2]), name='outList')
            rnn_output_tensor = rnn_output_list[-1];

            # Full connected network
            numberOfInputDims = inputTensor.shape[1].value
            numOfNodesInLayer = 1
            if not reuseParam:
                self.d_w = tf.Variable(initial_value=tf.random_normal([numberOfInputDims, numOfNodesInLayer]),
                        name=('dnet_w_1'))
                self.d_b = tf.Variable(tf.zeros([1, numOfNodesInLayer]), name='dnet_b_1')
            self.d_z = tf.matmul(rnn_output_tensor,self.d_w) + self.d_b
            self.d_z = tf.identity(self.d_z, name='dnet_z_1')
            d_sigmoid = tf.nn.sigmoid(self.d_z, name='dnet_a_1')

            d_logit = self.d_z
            d_logit = tf.identity(d_logit, 'd_net_logit')
            return d_logit

D網路損失函式 
D網路使用同一套引數分辨兩種輸入,一種是ground truth,另一種是G網路的輸出。對於ground truth,訓練目標為儘可能判為1,對於G網路的輸出,訓練目標為儘可能判為0,因此Loss函式定義如下

# For D-network, jduge ground truth to TRUE, jduge G-network output to FALSE,making loss low
            d_loss_ground_truth = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.d_logit_gnd_truth,
                    labels=tf.ones(shape=[self.batch_size_t,1])
                    ),
                name='d_loss_gnd'
                )

            d_loss_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.d_logit_fake,
                    labels=tf.zeros(shape=[self.batch_size_t,1])
                    ),
                name='d_loss_fake'
                )

            d_loss = d_loss_ground_truth + d_loss_fake

對抗訓練 
對抗訓練中,G網路Loss值只用來調整G網路引數,D網路Loss值只用來調整D網路引數

       g_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G_net')
        g_net_var_list = g_net_var_list +  tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='g_rnn')
        self.train_g = tf.train.AdamOptimizer(self.lr_g).minimize(g_loss,var_list=g_net_var_list)

        d_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='D_net')
        d_net_var_list = d_net_var_list +  tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='d_rnn')
        self.train_d = tf.train.AdamOptimizer(self.lr_d).minimize(d_loss,var_list=d_net_var_list)

訓練效果

下圖是訓練過程中D網路對ground truth和G網路輸出的分類正確率曲線 

從圖中可以看到3個階段

  1. 訓練開始後一秒左右:“D網路對ground truth的分類正確率”和“D網路對G網路輸出的分類正確率”都快速上升到100%,即D網路經過訓練可以完全正確的將真的判為真,假的判為假
  2. 訓練後1-15s:D網路分類正確率保持全對
  3. 15s之後:“D網路對ground truth的分類正確率”和“D網路對G網路輸出的分類正確率”出現震盪,表明在這個階段G網路已經能夠以假亂真,D網路將部分G網路輸出判為真,同時也將部分ground truth判為假。

上述3個階段就體現出對抗訓練的特點,G網路和D網路互為對手,互相提高對方的訓練難度,最終得到符合預期的模型。

接下來再從資料上給一個直觀的認識

Ground truth: 在公差為1的等差數列上加入stddev=0.3, mean=0的正態分佈噪聲後,得到的一組Ground Truth資料如下

[ 1.1539436 ] 
[ 2.08863655] 
[ 2.78491645] 
[ 3.93027817] 
[ 4.75851967] 
[ 5.88655699] 
[ 7.10540526] 
[ 7.43159023] 
[ 9.19373617] 
[ 10.08779359]

訓練開始前G網路的資料

基本無規律,和輸入噪聲分佈接近

[ 1.15080559] 
[ 0.66351247] 
[-0.39484465] 
[-0.41690648] 
[ 0.29061955] 
[ 0.06131642] 
[-2.46439648] 
[-1.53692639] 
[-0.30550677] 
[-0.89200932]

迭代100次之後G網路的輸出 
出現等差數列的端倪

[ -0.53692651] 
[ 0.86063552] 
[ 2.47294378] 
[ 5.24512053] 
[ 7.7618413 ] 
[ 9.57867622] 
[ 9.15039253] 
[ 9.86567402] 
[ 10.62975025] 
[ 10.24322414]

迭代500次之後G網路的輸出 
除了最後一個元素,前9個元素已經基本符合預期

[ 1.09549832] 
[ 2.21490908] 
[ 2.95311546] 
[ 4.06684017] 
[ 4.96308947] 
[ 6.03393888] 
[ 6.89026165] 
[ 7.93375683] 
[ 8.63552094] 
[ 9.07077026]

迭代1500次之後G網路的輸出 
已經足以以假亂真

[ 0.07186054] 
[ 1.08289695] 
[ 2.55904818] 
[ 4.07374573] 
[ 5.14763832] 
[ 6.07010031] 
[ 6.79585028] 
[ 8.17086124] 
[ 8.81297684] 
[ 10.38190079]

更多資料

本文首發於:。部落格中的內容體系性不如在知乎整理的清楚,但會隨時記錄工作中的技術問題和發現,如有興趣歡迎圍觀。

相關推薦

[TensorFlow]生成對抗網路(GAN)介紹實踐

主旨本文簡要介紹了生成對抗網路(GAN)的原理,接下來通過tensorflow開發程式實現生成對抗網路(GAN),並且通過實現的GAN完成對等差數列的生成和識別。通過對設計思路和實現方案的介紹,本文可以輔助讀者理解GAN的工作原理,並掌握實現方法。有了這樣的基礎,在面對工作中實際問題時可以將GAN納入考慮,選

生成對抗網路(GAN)的理論應用完整入門介紹

文章來源:https://blog.csdn.net/blood0604/article/details/73635586?locationNum=1&fps=1本文包含以下內容:1.為什麼生成模型值得研究2.生成模型的分類3.GAN相對於其他生成模型相比有什麼優勢4

簡單理解實驗生成對抗網路GAN

之前 GAN網路是近兩年深度學習領域的新秀,火的不行,本文旨在淺顯理解傳統GAN,分享學習心得。現有GAN網路大多數程式碼實現使用python、torch等語言,這裡,後面用matlab搭建一個簡單的GAN網路,便於理解GAN原理。 分享一個目前各類G

ICCV2017 | 一文詳解GAN之父Ian Goodfellow 演講《生成對抗網路的原理應用》(附完整PPT)

當地時間 10月 22 日到10月29日,兩年一度的計算機視覺國際頂級會議 International Conference on Computer Vision(ICCV 2017)在義大利威尼斯開幕。Google Brain 研究科學家Ian Goodfellow在會上作為主題為《生成對抗網路(G

洞見 | 生成對抗網路GAN最近在NLP領域有哪些應用

分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow 也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!        

生成對抗網路GAN資料打包分享

  全文摘要 生成式對抗網路,即所謂的GAN是近些年來最火的無監督學習方法之一,模型由Goodfellow等人在2014年首次提出,將博弈論中非零和博弈思想與生成模型結合在一起,巧妙避開了傳統生成模型中概率密度估計困難等問題,是生成模型達到良好的效果。本文總結收集了

生成對抗網路GAN系列(六)--- CycleGAN---文末附程式碼

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks Jun-Yan Zhu      Taesung Park          Phillip Isola   

生成對抗網路(GAN)的變體pix2pix思想

1.概述 pix2pix 是對抗生成網路的一種變體,它的結構類似於CGAN,但又有別於CGAN。先來說一下它能做哪些事情,顧名思義就是將一張圖片轉成另一張圖片(千萬不要理解成畫素變畫素啊),或者說將一個場景轉換成另一場景。pix2pix 能做的事情有很多,比如說

生成對抗網路GAN(一) 簡介和變種

基本概念[1] 目標函式 零和遊戲(zero-sum game) 納什均衡 minimax演算法 GAN借鑑了零和遊戲的思想,引入生成網路和辨別網路,讓兩個網路互相博弈,當辨別網路不能辨別資料來自於真實分佈還是生成網路的時候,此時的生成網路可以

洞見 | 生成對抗網路GAN最近在NLP領域有哪些應用?

剛做完實驗,來答一答自然語言處理方面GAN的應用。 直接把GAN應用到NLP領域(主要是生成序列),有兩方面的問題: 1. GAN最開始是設計用於生成連續資料,但是自然語言處理中我們要用來生成離散tokens的序列。因為生成器(Generator,簡稱G)需要利用從

白話生成對抗網路 GAN,50 行程式碼玩轉 GAN 模型!【附原始碼】

今天,紅色石頭帶大家一起來了解一下如今非常火熱的深度學習模型:生成對抗網路(Generate Adversarial Network,GAN)。GAN 非常有趣,我就以最直白的語言來講解它,最後實現一個簡單的 GAN 程式來幫助大家加深理解。 1. 什

生成對抗網路(GAN)簡單梳理

0 前言 GAN(Generative Adversarial Nets)是用對抗方法來生成資料的一種模型。和其他機器學習模型相比,GAN引人注目的地方在於給機器學習引入了對抗這一理念。 回溯地球生物的進化路線就會發現,萬物都是在不停的和其他事物對抗

生成對抗網路GAN的前世今生

綜合論文摘要以及網上文章整理。 2014年,蒙特利爾大學的Ian Goodfellow和他的同事創造了生成式對抗網路(GAN) 論文:Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adve

生成對抗網路GAN---生成mnist手寫數字影象示例(附程式碼)

Ian J. Goodfellow等人於2014年在論文Generative Adversarial Nets中提出了一個通過對抗過程估計生成模型的新框架。框架中同時訓練兩個模型:一個生成模型(generative model)G,用來捕獲資料分佈;一個判別模型(discri

生成對抗網路(GAN)的一些知識整理(課件)

無監督學習是機器學習的未來,而現在GAN的出現,則為無監督學習帶來了光明。 鑑於GAN的火熱,最近將從一些大牛分享資料中擷取和整理的資料附圖如下: 最近測試了一下tensorflow環境下gan

生成對抗網路——GAN(一)

Generative adversarial network 據有關媒體統計:CVPR2018的論文裡,有三分之一的論文與GAN有關 由此可見,GAN在視覺領域的未來多年內,將是一片沃土(CVer們是時候入門GAN了)。而發現這片礦源的就是GAN之父,Goodf

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網路GAN

Generative Adversarial Network 是深度學習中非常有趣的一種方法。GAN最早源自Ian Goodfellow的這篇論文。LeCun對GAN給出了極高的評價: “There are many interesting recent development in deep learni

一篇讀懂生成對抗網路GAN)原理+tensorflow程式碼實現

作者:JASON 2017.10.15   生成對抗網路GAN(Generative adversarial networks)是最近很火的深度學習方法,要理解它可以把它分成生成模型和判別模型兩個部分,簡單來說就是:兩個人比賽,看是 A 的矛厲害,還是 B

tensorflow 1.01中GAN(生成對抗網路)手寫字型生成例子(MINST)的測試

為了更好地掌握GAN的例子,從網上找了段程式碼進行跑了下,測試了效果。具體過程如下: 程式碼檔案如下: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data i

生成對抗網路的簡單介紹TensorFlow 程式碼)

原文地址: 引言 最近,研究者們對生成模型的興趣一直很大(參見OpenAI的這篇部落格文章)。這些生成模型是可以學習建立類似於我們給它們的資料。這樣的直觀感受是,如果我們可以得到一個能寫出高質量的新聞文章的模型,那麼它一般也會學到很多關於新聞文章的內