[TensorFlow]生成對抗網路(GAN)介紹與實踐
主旨
本文簡要介紹了生成對抗網路(GAN)的原理,接下來通過tensorflow開發程式實現生成對抗網路(GAN),並且通過實現的GAN完成對等差數列的生成和識別。通過對設計思路和實現方案的介紹,本文可以輔助讀者理解GAN的工作原理,並掌握實現方法。有了這樣的基礎,在面對工作中實際問題時可以將GAN納入考慮,選擇最合適的演算法。
程式碼和執行環境
TensorFlow版本
>>> tf.version
‘1.1.0-rc2’
背景知識
Generative Adversarial Nets[1][https://arxiv.org/pdf/1406.2661v1.pdf]是Ian J. Goodfellow等在2014年提出的一種訓練模型的方法,此方法通過兩個網路(生成網路G和分類網路D)對抗訓練,得到符合預期目標的生成模型和分類模型。
要理解GAN的原理,上述論文是最好的教材。但考慮到原文首先是英文撰寫,其次包含不少數學推導,新手上手並不容易。因此筆者這裡班門弄斧,基於論文簡單轉述GAN的設計思想要點
GAN的目標,給定一個真實樣本(本文也稱之為ground truth)集合,訓練出兩個模型,一個能夠從噪聲訊號生成儘可能像ground truth的樣本;另一個能夠判斷給定樣本是否是ground truth。兩個模型詳細介紹如下
- 生成模型:論文中稱為generative model,本文稱為G網路或G模型。G網路的輸入是噪聲訊號(例如均勻分佈的隨機數),輸出為形狀與真實樣本ground truth一致。G網路的訓練目標是,儘可能輸出與ground truth相似的樣本。這裡“相似”定義為:如果G網路生成的一個樣本騙過了D網路,使得D網路誤以為這就是真實樣本,則就是相似的,G網路獲得獎勵;反之,獲得懲罰。
- 分類模型:論文中稱為discriminative model,本文稱為D網路或D模型。D網路是一個2分類器,輸入為ground truth或者G網路生成的樣本,輸出為TRUE或FALSE:TRUE表示D網路認為當前輸入樣本是ground truth,FALSE表示D網路認為當前輸入樣本是G網路生成的“偽造”樣本。D網路的訓練目標是儘可能正確的區分開ground truth和G網路生成的“偽造”樣本。
從上述討論可以看出,G網路和D網路是兩個目標完全相反的網路,G網路盡其所能“偽造”出像真實樣本的資料,D網路儘可能區分真實與偽造資料。GAN中所謂“對抗”的概念,即來源於此。
GAN的訓練過程就是G和D兩個網路互相對抗的過程,對抗的結果是G網路被訓練到能夠生成以假亂真的樣本,即G網路從噪聲輸入得到了儘可能與真實樣本相似的輸出,或者說G學會了從噪聲生成ground truth的方法;D網路也可以區分ground truth與其他樣本,即D學會了區分ground truth與其他資料的方法。
參考文獻
1. Goodfellow I J, Pougetabadie J, Mirza M, et al. Generative adversarial nets[C]. neural information processing systems, 2014: 2672-2680.
神經網路設計和實現
問題構造
在開始設計神經網路之前,我們首先構造出預期GAN解決的問題。前述GAN論文中提出了一個從噪聲學習正態分佈的經典問題,讀者如果在網路上搜索GAN的案例,除了影象識別,基本上只有這麼一個問題和方案實現。
本文重新設計了一個與論文中不同的問題。問題描述如下
- Ground Truth定義:[1,2,3,4,5,6,7,8,9,10]構成的等差數列,為了適當降低學習難度,此數列每個元素與噪聲相加,噪聲為0均值正態分佈隨機變數,標準差取0.1, 0.03, 0等不同數值
- 輸入噪聲定義: [-1,1]之間均勻分佈的隨機變數。
網路結構設計
G網路:參考論文資料,我們選擇多層全連線神經網路
D網路:由於要分辨的是等差數列,我們選擇RNN作為D網路。
網路結構如下(下圖是tensorboard生成的計算圖):圖中”G_net”表示G網路,”D_net”/”D_net_1”表示D網路,雖然圖中D網路被分成了兩份,但是其RNN引數是共享的,即圖中正下方”rnn”這個單元。
程式碼實現
G網路定義
# generative network
# use multi-layer percepton to generate time sequence from random noise
# input tensor must be in shape of (batch_size, self.seq_len)
def generator(self, inputTensor):
with tf.name_scope('G_net'):
gInputTensor = tf.identity(inputTensor, name='input')
# Multilayer percepton implementation
numNodesInEachLayer = 10
numLayers = 3
previous_output_tensor = gInputTensor
for layerIdx in range(numLayers):
activation,z = self.fullConnectedLayer(previous_output_tensor, numNodesInEachLayer, layerIdx)
previous_output_tensor = activation
g_logit = z
g_logit = tf.identity(g_logit, 'g_logit')
return g_logit
G網路損失函式
下面程式碼片段中self.d_logit_fake是D網路對G網路生成資料的判定結果。由於G網路的目標是儘可能騙過D網路,如果D網路對於G網路生成資料全部判為1(即TRUE),則損失最小,反之,損失最大。
g_loss_d = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.d_logit_fake,
labels=tf.ones(shape=[self.batch_size_t,1])
),
name='g_loss_d'
)
D網路的定義
RNN+全連線輸出層,無論是RNN還是全連線層都必須在對ground truth和G生成樣本之間共享同一套引數
def discriminator(self, inputTensor,reuseParam):
with tf.name_scope('D_net'):
num_units_in_LSTMCell = 10
# RNN definition
with tf.variable_scope('d_rnn'):
lstmCell = tf.contrib.rnn.BasicLSTMCell(num_units_in_LSTMCell,reuse=reuseParam)
init_state = lstmCell.zero_state(self.batch_size_t, dtype=tf.float32)
raw_output, final_state = tf.nn.dynamic_rnn(lstmCell, inputTensor, initial_state=init_state)
rnn_output_list = tf.unstack(tf.transpose(raw_output, [1, 0, 2]), name='outList')
rnn_output_tensor = rnn_output_list[-1];
# Full connected network
numberOfInputDims = inputTensor.shape[1].value
numOfNodesInLayer = 1
if not reuseParam:
self.d_w = tf.Variable(initial_value=tf.random_normal([numberOfInputDims, numOfNodesInLayer]),
name=('dnet_w_1'))
self.d_b = tf.Variable(tf.zeros([1, numOfNodesInLayer]), name='dnet_b_1')
self.d_z = tf.matmul(rnn_output_tensor,self.d_w) + self.d_b
self.d_z = tf.identity(self.d_z, name='dnet_z_1')
d_sigmoid = tf.nn.sigmoid(self.d_z, name='dnet_a_1')
d_logit = self.d_z
d_logit = tf.identity(d_logit, 'd_net_logit')
return d_logit
D網路損失函式
D網路使用同一套引數分辨兩種輸入,一種是ground truth,另一種是G網路的輸出。對於ground truth,訓練目標為儘可能判為1,對於G網路的輸出,訓練目標為儘可能判為0,因此Loss函式定義如下
# For D-network, jduge ground truth to TRUE, jduge G-network output to FALSE,making loss low
d_loss_ground_truth = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.d_logit_gnd_truth,
labels=tf.ones(shape=[self.batch_size_t,1])
),
name='d_loss_gnd'
)
d_loss_fake = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.d_logit_fake,
labels=tf.zeros(shape=[self.batch_size_t,1])
),
name='d_loss_fake'
)
d_loss = d_loss_ground_truth + d_loss_fake
對抗訓練
對抗訓練中,G網路Loss值只用來調整G網路引數,D網路Loss值只用來調整D網路引數
g_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G_net')
g_net_var_list = g_net_var_list + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='g_rnn')
self.train_g = tf.train.AdamOptimizer(self.lr_g).minimize(g_loss,var_list=g_net_var_list)
d_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='D_net')
d_net_var_list = d_net_var_list + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='d_rnn')
self.train_d = tf.train.AdamOptimizer(self.lr_d).minimize(d_loss,var_list=d_net_var_list)
訓練效果
下圖是訓練過程中D網路對ground truth和G網路輸出的分類正確率曲線
從圖中可以看到3個階段
- 訓練開始後一秒左右:“D網路對ground truth的分類正確率”和“D網路對G網路輸出的分類正確率”都快速上升到100%,即D網路經過訓練可以完全正確的將真的判為真,假的判為假
- 訓練後1-15s:D網路分類正確率保持全對
- 15s之後:“D網路對ground truth的分類正確率”和“D網路對G網路輸出的分類正確率”出現震盪,表明在這個階段G網路已經能夠以假亂真,D網路將部分G網路輸出判為真,同時也將部分ground truth判為假。
上述3個階段就體現出對抗訓練的特點,G網路和D網路互為對手,互相提高對方的訓練難度,最終得到符合預期的模型。
接下來再從資料上給一個直觀的認識
Ground truth: 在公差為1的等差數列上加入stddev=0.3, mean=0的正態分佈噪聲後,得到的一組Ground Truth資料如下
[ 1.1539436 ]
[ 2.08863655]
[ 2.78491645]
[ 3.93027817]
[ 4.75851967]
[ 5.88655699]
[ 7.10540526]
[ 7.43159023]
[ 9.19373617]
[ 10.08779359]
訓練開始前G網路的資料
基本無規律,和輸入噪聲分佈接近
[ 1.15080559]
[ 0.66351247]
[-0.39484465]
[-0.41690648]
[ 0.29061955]
[ 0.06131642]
[-2.46439648]
[-1.53692639]
[-0.30550677]
[-0.89200932]
迭代100次之後G網路的輸出
出現等差數列的端倪
[ -0.53692651]
[ 0.86063552]
[ 2.47294378]
[ 5.24512053]
[ 7.7618413 ]
[ 9.57867622]
[ 9.15039253]
[ 9.86567402]
[ 10.62975025]
[ 10.24322414]
迭代500次之後G網路的輸出
除了最後一個元素,前9個元素已經基本符合預期
[ 1.09549832]
[ 2.21490908]
[ 2.95311546]
[ 4.06684017]
[ 4.96308947]
[ 6.03393888]
[ 6.89026165]
[ 7.93375683]
[ 8.63552094]
[ 9.07077026]
迭代1500次之後G網路的輸出
已經足以以假亂真
[ 0.07186054]
[ 1.08289695]
[ 2.55904818]
[ 4.07374573]
[ 5.14763832]
[ 6.07010031]
[ 6.79585028]
[ 8.17086124]
[ 8.81297684]
[ 10.38190079]
更多資料
本文首發於:。部落格中的內容體系性不如在知乎整理的清楚,但會隨時記錄工作中的技術問題和發現,如有興趣歡迎圍觀。
相關推薦
[TensorFlow]生成對抗網路(GAN)介紹與實踐
主旨本文簡要介紹了生成對抗網路(GAN)的原理,接下來通過tensorflow開發程式實現生成對抗網路(GAN),並且通過實現的GAN完成對等差數列的生成和識別。通過對設計思路和實現方案的介紹,本文可以輔助讀者理解GAN的工作原理,並掌握實現方法。有了這樣的基礎,在面對工作中實際問題時可以將GAN納入考慮,選
生成對抗網路(GAN)的理論與應用完整入門介紹
文章來源:https://blog.csdn.net/blood0604/article/details/73635586?locationNum=1&fps=1本文包含以下內容:1.為什麼生成模型值得研究2.生成模型的分類3.GAN相對於其他生成模型相比有什麼優勢4
簡單理解與實驗生成對抗網路GAN
之前 GAN網路是近兩年深度學習領域的新秀,火的不行,本文旨在淺顯理解傳統GAN,分享學習心得。現有GAN網路大多數程式碼實現使用python、torch等語言,這裡,後面用matlab搭建一個簡單的GAN網路,便於理解GAN原理。 分享一個目前各類G
ICCV2017 | 一文詳解GAN之父Ian Goodfellow 演講《生成對抗網路的原理與應用》(附完整PPT)
當地時間 10月 22 日到10月29日,兩年一度的計算機視覺國際頂級會議 International Conference on Computer Vision(ICCV 2017)在義大利威尼斯開幕。Google Brain 研究科學家Ian Goodfellow在會上作為主題為《生成對抗網路(G
洞見 | 生成對抗網路GAN最近在NLP領域有哪些應用
分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow 也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!  
生成對抗網路GAN資料打包分享
全文摘要 生成式對抗網路,即所謂的GAN是近些年來最火的無監督學習方法之一,模型由Goodfellow等人在2014年首次提出,將博弈論中非零和博弈思想與生成模型結合在一起,巧妙避開了傳統生成模型中概率密度估計困難等問題,是生成模型達到良好的效果。本文總結收集了
生成對抗網路GAN系列(六)--- CycleGAN---文末附程式碼
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks Jun-Yan Zhu Taesung Park Phillip Isola
生成對抗網路(GAN)的變體pix2pix思想
1.概述 pix2pix 是對抗生成網路的一種變體,它的結構類似於CGAN,但又有別於CGAN。先來說一下它能做哪些事情,顧名思義就是將一張圖片轉成另一張圖片(千萬不要理解成畫素變畫素啊),或者說將一個場景轉換成另一場景。pix2pix 能做的事情有很多,比如說
生成對抗網路GAN(一) 簡介和變種
基本概念[1] 目標函式 零和遊戲(zero-sum game) 納什均衡 minimax演算法 GAN借鑑了零和遊戲的思想,引入生成網路和辨別網路,讓兩個網路互相博弈,當辨別網路不能辨別資料來自於真實分佈還是生成網路的時候,此時的生成網路可以
洞見 | 生成對抗網路GAN最近在NLP領域有哪些應用?
剛做完實驗,來答一答自然語言處理方面GAN的應用。 直接把GAN應用到NLP領域(主要是生成序列),有兩方面的問題: 1. GAN最開始是設計用於生成連續資料,但是自然語言處理中我們要用來生成離散tokens的序列。因為生成器(Generator,簡稱G)需要利用從
白話生成對抗網路 GAN,50 行程式碼玩轉 GAN 模型!【附原始碼】
今天,紅色石頭帶大家一起來了解一下如今非常火熱的深度學習模型:生成對抗網路(Generate Adversarial Network,GAN)。GAN 非常有趣,我就以最直白的語言來講解它,最後實現一個簡單的 GAN 程式來幫助大家加深理解。 1. 什
生成對抗網路(GAN)簡單梳理
0 前言 GAN(Generative Adversarial Nets)是用對抗方法來生成資料的一種模型。和其他機器學習模型相比,GAN引人注目的地方在於給機器學習引入了對抗這一理念。 回溯地球生物的進化路線就會發現,萬物都是在不停的和其他事物對抗
生成對抗網路GAN的前世今生
綜合論文摘要以及網上文章整理。 2014年,蒙特利爾大學的Ian Goodfellow和他的同事創造了生成式對抗網路(GAN) 論文:Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adve
生成對抗網路GAN---生成mnist手寫數字影象示例(附程式碼)
Ian J. Goodfellow等人於2014年在論文Generative Adversarial Nets中提出了一個通過對抗過程估計生成模型的新框架。框架中同時訓練兩個模型:一個生成模型(generative model)G,用來捕獲資料分佈;一個判別模型(discri
生成對抗網路(GAN)的一些知識整理(課件)
無監督學習是機器學習的未來,而現在GAN的出現,則為無監督學習帶來了光明。 鑑於GAN的火熱,最近將從一些大牛分享資料中擷取和整理的資料附圖如下: 最近測試了一下tensorflow環境下gan
生成對抗網路——GAN(一)
Generative adversarial network 據有關媒體統計:CVPR2018的論文裡,有三分之一的論文與GAN有關 由此可見,GAN在視覺領域的未來多年內,將是一片沃土(CVer們是時候入門GAN了)。而發現這片礦源的就是GAN之父,Goodf
在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網路 (GAN
Generative Adversarial Network 是深度學習中非常有趣的一種方法。GAN最早源自Ian Goodfellow的這篇論文。LeCun對GAN給出了極高的評價: “There are many interesting recent development in deep learni
一篇讀懂生成對抗網路(GAN)原理+tensorflow程式碼實現
作者:JASON 2017.10.15 生成對抗網路GAN(Generative adversarial networks)是最近很火的深度學習方法,要理解它可以把它分成生成模型和判別模型兩個部分,簡單來說就是:兩個人比賽,看是 A 的矛厲害,還是 B
tensorflow 1.01中GAN(生成對抗網路)手寫字型生成例子(MINST)的測試
為了更好地掌握GAN的例子,從網上找了段程式碼進行跑了下,測試了效果。具體過程如下: 程式碼檔案如下: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data i
生成對抗網路的簡單介紹(TensorFlow 程式碼)
原文地址: 引言 最近,研究者們對生成模型的興趣一直很大(參見OpenAI的這篇部落格文章)。這些生成模型是可以學習建立類似於我們給它們的資料。這樣的直觀感受是,如果我們可以得到一個能寫出高質量的新聞文章的模型,那麼它一般也會學到很多關於新聞文章的內