1. 程式人生 > >從零上手變分自編碼器(VAE)

從零上手變分自編碼器(VAE)

閱讀更多,歡迎關注公眾號:論文收割機(paper_reader)

  • Kingma D P, Welling M. Auto-encoding variational bayes[J]. arXiv preprint arXiv:1312.6114, 2013.
  • Rezende D J, Mohamed S, Wierstra D. Stochastic backpropagation and approximate inference in deep generative models[J]. arXiv preprint arXiv:1401.4082, 2014.

關於VAE,網上的教程很多,但是通俗易懂又直擊本質的中文教程幾乎沒有。VAE雖然是一個深度學習模型,但在設計過程中利用了很多概率統計知識。很多教程僅僅是長篇大論地推導公式,或是直接指出它作為深度模型的結構和約束,然而對於VAE的設計動機,卻很少有教程提及。

這篇文章不僅同時介紹了VAE作為深度模型和概率模型的兩方面特徵,還對作者進行設計的動機進行了仔細的解釋。筆者水平有限錯誤難以避免,但是希望借撰寫這篇文章的機會,能和讀者一起更好地瞭解這一模型。

1. VAE的適用情境和動機

VAE (Variational Autocoder, 變分自編碼器) 同GAN一樣,是近年來一種常用的深度生成模型。該模型最早分別由Kingma等人和Rezende等人在ICLR 14' 和ICML14' 上獨立發表。

  • https://arxiv.org/abs/1312.6114 

  • https://arxiv.org/abs/1401.4082

本文主要參照了Kingma的文章,對VAE的設計動機進行分析。

VAE作為一個生成模型,它的主要目的,就是為資料的隱變數加上先驗。為了描述這一動機,我們可以參考如下的圖模型:

其中,x是與訓練資料對應的隨機變數,而z對應的是影響資料分佈的那些隱變數。

例如,對於手寫數字識別這一問題,x對應的是資料集中的一張張圖片,而z對應的是資料的標籤,也就是這張圖片所顯示的數字是多少。顯而易見,對於手寫數字識別這個問題,每張圖片中圖案的形狀,顯然是由寫字者想表達的數字來決定的。

或者說,每張寫有數字“5” 的圖片在形式上都大同小異,因為它們都想表達出 “數字5” 這一資訊。而對於整個資料集來說,每張圖片所表達的資訊,也都是“數字0” ,“數字1” ,“數字2” …… “數字0”這10類其中之一。

因此,引數theta所對應的概率分佈ptheta(z),代表的是資料集內圖片所表示的各種數字的聯合概率分佈情況,隨機變數z代表代表單張圖片所代表的數字,而x則代表這張圖片各個點的畫素值。

同樣地,在其他任務中,z與x也可以被賦予其他不同的含義,例如z與x分別對應“語義—文字”, “疾病—症狀” ,或是單純用z來代表資料的某個表達,也都是可以的。

而作為無監督的生成模型,VAE的應用條件也非常寬鬆,只要有無標註的訓練資料,再為隱變數設定一個先驗概率分佈,就可以進行訓練了。訓練完成後,我們能得到的結果也是兩個:

  1. 對訓練資料的模擬,作為生成模型,VAE能生成與訓練資料同分布的結果。

  2. 隱變數本身。VAE能為資料集中每個資料生成一個隱變數z,通過對隱變數分佈情況的分析,可以得到一些關於資料分佈情況的結論。

在下兩章,我們將從深度學習和概率模型兩個視角討論VAE的設計邏輯。對於只關心VAE長什麼樣子,該怎麼使用的讀者,可以跳過這兩章閱讀。

2. 深度學習視角的自編碼器

自編碼器是一種無監督的深度學習模型,用來學習訓練資料的某種壓縮表達。最基本的自編碼器結構非常簡單。如下圖所示,自編碼器由編碼器和解碼器兩部分組成,通過編碼器可以學出對原資料的壓縮表達,再通過解碼器復原出原資料。

這裡得到的z確實是原資料x的壓縮表達,但z卻很難被看成是x對應的隱變數,因為第一,z和x之間不存在明顯的因果關係,第二,z也不服從我們對隱變數的先驗分佈ptheta(z)。因此,VAE做出了兩個改進。

第一個改進是,傳統自編碼器中都是直接生成z的值,而VAE是先生成一個關於z的概率分佈q(z|x),再用這個分佈生成z。

第二個改進是,將隱變數z的分佈情況也以KL散度的的形式加進模型的優化目標loss裡

這樣當模型優化完成後,隱變數z的分佈與先驗分佈之間的KL散度近似到0,我們就可以說隱變數z服從我們的先驗了。

而這兩個改進又帶來了新的問題:從分佈中取樣這個操作是無法求偏導的,利用BP演算法優化模型的時候,梯度傳到這裡就無法再計算下去了。作者在文章的後面利用Reparameterization trick解決了這一問題。

3. 概率模型視角的變分推斷

VAE中實際上使用的概率圖模型如下:

除了theta之外,圖中引入了另外一個引數phi。這個模型的意義是這樣的:資料x基於隱變數z生成,隱變數z服從引數為theta的分佈ptheta(z),但由於p(x)和p(z|x)難以估計,使得theta的值難以優化。

比如說,通過患者的症狀估計患者所患的疾病,雖然這兩者之間存在因果關係,但實際上兩者之間關係是十分複雜的。

因此,我們引入另外一個引數phi和對應的分佈qphi(z|x),這是一個對複雜情況的簡單估計。分佈q往往是數學上容易計算的某些分佈,phi是這個分佈的引數。我們通過資料x來計算出phi,就能得到真實因果關係的一個近似。

在VAE中,通過x得到分佈qphi(z|x)再得到z的這一過程被看做是編碼器,而通過z還原x的這一過程,也就是從分佈ptheta(x|z)中得出x的過程,被看做是解碼器

引數phi和theta就是深度網路模型中各個神經網路單元的引數,也就是我們要優化的引數。優化模型中引數phi和theta的過程如下,注意這裡引入了變分下限:

這裡應用了變分推斷的技巧,等式左側是資料集內某一條資料出現的概率的對數,是一個無法被計算的常數,等式右側第一項KL散度表示qphi(z|x)對ptheta(z|x)的逼近程度,是我們想要優化(最小化)的物件,第二項往往被稱作evidencelower bound (ELBO,變分下限)。優化的目標是使變分下限最大,這樣在等式左側不變的情況下,KL散度就能取到最小。

而在論文中,作者給出了該情況下變分下限的兩種形式:

分別記為(1)式和(2)式。原本來說,對於這樣的公式,只要對引數求偏導就可以優化了。然而對於這個兩個式子,直接求偏導進行優化會導致結果的極大波動,被認為是不可行的。

論文中原文如下:The usual (naïve) MonteCarlo gradient estimator for this type of problem exhibits exhibits very highvariance and is impractical for our purposes. 

對於這一問題的具體分析,有興趣的讀者可以根據原文中提供的參考文獻進一步瞭解。

筆者認為,這一問題的原因和深度學習視角中模型無法利用BP演算法的原因是對應的。在深度模型中出現無法求偏導的取樣操作,在圖模型推導中,也對應出現了一個難以優化的概率期望E。為了解決這個問題,作者想出了一個叫做Reparameterization trick的技巧,使得對於某些常見的先驗q(z|x)和p(z)(例如高斯分佈),模型可以高效簡潔地進行推斷。

4. Reparameterization trick和兩個視角的統一

以高斯分佈為例,Reparameterization trick簡單來說是這樣的。

考慮從一個分佈中抽樣的過程:

這個過程顯然是不可導的,但如果是從一個正態分佈中抽樣,該分佈的均值和方差分別為μ和σ的話,上面的抽樣過程就可以變換成下面的形式:

顯然這兩種抽樣方法得到的z是同分布的。但是第二種方法顯然可以針對引數μ和σ求偏導數。同樣,將z=μ + σε帶入變分下限的表示式中,也可以最優化變分下限。

論文中給出了應用Reparameterization trick後,對於變分下限的(1)形式,變分下限的近似值LA:

其中,函式g是應用Reparameterization trick後,通過x與ε計算z的關係式(在正態分佈下就是z=μ+ σε),l是對於一條資料而言取樣ε(其實也就是z)的次數。

在實際應用中資料量夠大時,對每條資料取樣1次就可以了。

之後可以進一步估計整個資料集中的邊際似然下界:

其中N是資料集中資料個數,M是minibatch的大小。

論文中還給出了訓練演算法的minibatch形式:

然而,無論是在論文的示例中還是其他什麼地方,上述的這些複雜的訓練方法都很少被用到,換句話說,上面那些都沒什麼卵用:)。

真正被使用的是變分下限的(2)形式,也就是這個:

帶入Reparameterization trick之後形式是這樣的:

這個式子中的KL散度一項,在應用Reparameterization trick之後,往往能直接計算(論文後文只給出了當p和q都是高斯分佈時的推導,其他分佈時的推導還是要靠自己啊),而後面的期望一項,不正是訓練自編碼器時候的重建誤差嗎?

論文中還說,這個方法因為需要取樣估計的地方更少,所以和上一個方法比起來還更容易收斂。

所以繞了一大圈,我們又回到了原點。VAE的主要結構如下(以高斯分佈為例):

其中μ和σ是在先驗分佈q為高斯分佈時的引數,ε是從標準高斯分佈中取樣的結果。這個模型的loss也分為兩部分組成,每部分形式和意義如下:

訓練完成後,可以將隱變數z的分佈做視覺化,也可以利用解碼器來生成更多資料。

5. 例項:手寫數字識別

在這個例子