1. 程式人生 > >你瞭解變分自編碼器嗎? 請看這裡

你瞭解變分自編碼器嗎? 請看這裡

10.9  變分自編碼器
前面所描述的自編碼器可以降維重構樣本,在這基礎上我們來學習一個更強大的自編碼器。
10.9.1  什麼是變分自編碼器
變分自編碼器學習的不再是樣本的個體,而是要學習樣本的規律。這樣訓練出來的自編碼器不單單具有重構樣本的功能,還具有了仿照樣本的功能。
聽起來這麼強大的功能,到底是怎麼做到的?下面我們來講講它的原理。
變分自編碼器,其實就是在編碼過程中改變了樣本的分佈(“變分”可以理解為改變分佈)。前文中所說的“學習樣本的規律”,具體指的就是樣本的分佈,假設我們知道樣本的分佈函式,就可以從這個函式中隨便取一個樣本,然後進行網路解碼層前向傳導,就可以生成了一個新的樣本。
為了得到這個樣本的分佈函式,我們的模型訓練目的將不再是樣本本身,而是通過加一個約束項,將我們的網路生成一個服從於高斯分佈的資料集,這樣按照高斯分佈裡的均值和方差規則可以任意取相關的資料,然後通過解碼層還原成樣本。
10.9.2  例項82:使用變分自解碼器模擬生成MNIST資料

對於變分自解碼器,好多文獻都是給出了一堆晦澀難懂的公式,其實裡面真正的公式只有一個——KL離散度的計算。而它也屬於成熟的式子,就跟交叉熵一樣,直接拿來用就可以。
公式本來是語言的高度概括,而一篇文章全是公式沒有語言就會令人難以理解。本文只會有程式碼加上語言描述,不會讓這部分知識讀起來感覺晦澀。
程式碼例子共分如下幾個步驟,下面我們就來一一操作。
案例描述
使用變分自編碼模型進行模擬MNIST資料的生成。
1.引入庫,定義佔位符

這次建立的網路與以前略有不同,編碼為兩個全連線層由784到256再到兩個2層的並列輸出,然後將兩個輸出通過一個公式的計算,輸入到以一個2節點為開始的解碼部分,接著2個全連線層又2到256再到784。如圖10-17


                                             圖10-17 變分解碼器層次
具體的計算公式,後文會有詳細介紹。
在下面的程式碼中與前面程式碼不同,下面引入了一個scipy庫,在後面視覺化時會用到。標頭檔案引入之後,定義操作符x和z。x用於原始的圖片輸入,z用於中間節點解碼器的輸入。

程式碼10-8  變分自編碼器


zinput是個佔位符,在後面要通過它將分佈資料輸入,用來生成模擬樣本資料。
2.定義學習引數
由於這次的網路結構不同,所以定義的引數也有變化,mean_w1與mean_b1是生成mean的權重,log_sigma_w1與log_sigma_b1是生成log_sigma的權重。

程式碼10-8  變分自編碼器(續)




3.定義網路結構
按照上面圖10-16的描述,網路節點可以按照以下程式碼來定義,在變分解碼器為訓練的中間節點賦予了特殊的意義,讓它們代表均值和方差,並將他們所代表的資料集向著標準高斯分佈資料集靠近(也就是原始資料是樣本,高斯分佈資料是標籤),然後可以使用kl散度公式,來計算它所代表的集合與標準的高斯分佈集合(均值是0,方差為1的正態分佈)間的距離,將這個距離當成誤差讓它最小化從而來優化網路引數。
這裡的方差節點不是真正意義的方差,是取了log之後的。所以會有tf.exp(z_log_sigma_sq)的變換,是取得方差的值,再tf.sqrt將其開平方得到標準差。用符合標準正太分佈的一個數來乘上標準差加上均值,就使這個數成為符合(z_mean,sigma)資料分佈集合裡面的一個點(z_mean是指網路生成均值,sigma是指網路生成的z_log_sigma_sq變換後的值)。


到此,完成了編碼階段。將原始資料編碼輸出3個值:
● 一個是該表述資料分佈的均值,
● 一個是表述該資料分佈的方差,
● 還有一個是得到了該資料分佈中的一個實際的點z。


程式碼10-8  變分自編碼器(續)


得到了符合原資料集上的一個具體點z之後,就可以通過神經網路這個點z還原成原始資料reconstruction了。這個解碼部分還是和以前的內容一樣,參照編碼的網路逐層還原回去。
h2out和reconstructionout兩個節點不屬於訓練中的結構,是為了生成指定資料時用的。
4.構建模型的反向傳播
和以往一樣,需要定義損失函式的節點和優化演算法的op,程式碼如下。

程式碼10-8  變分自編碼器(續)


上面程式碼描述了網路兩個優化方向:
● 一個是比較生成的資料分佈與標準高斯分佈的距離,這裡使用KL離散度的公式(見latent_loss)。
● 另一個是計算生成資料與原始資料間的損失,這裡用的是平方差,也可以用交叉熵。
最後將兩種損失值放在一起,通過adam的隨機梯度下降演算法來實現在訓練中的優化引數。
5.設定引數,進行訓練
這步驟與前面類似,設定訓練引數,迭代50次,在session中每次迴圈取指定批次資料進行訓練。

程式碼10-8  變分自編碼器(續)


視覺化部分這裡不再詳述,可以參考本書的配套程式碼,最終程式執行的結果輸出如下,結果如圖10-18所示。



可以看到生成的數字,不再一味單純的學習形狀,而是通過資料分佈的方式學習規則,對原有圖片具有更清晰的修正功能。

仿照前面的視覺化程式碼,將均值和方差代表的二維資料在直角座標系中展現如下:


 
                                                           圖10-19變分自解碼二維視覺化
從圖10-19中可以看出,具有代表同一數值的圖片的特徵資料分佈還是比較集中的,說明變分位元組碼也具有降維功能,也可以用它進行分類任務的資料降維預處理部分。
6.高斯分佈取樣,生成模擬資料
為了進一步證實模型學到資料分佈的情況,我們這次在高斯分佈中抽樣去取一些點,將其對映到模型中的z,然後通過解碼部分還原成真實圖片看看效果,程式碼如下。

注意:


程式碼10-8  變分自編碼器(續)


執行以上程式碼生成如圖10-20所示圖片。


                                              圖10-20 變分自解碼生成模擬資料

可以看到,在神經網路的世界裡,所以左下角到右上角顯示了網路是按照圖片的形狀變化而排列的,並不像我們人類一樣,把數字按照1到9的排列,因為機器學的只是圖片,而人類對數字的理解更多的是在於它幕後的意思。

更多章節請購買《深入學習之 TensorFlow 入門、原理與進階實戰》全本