1. 程式人生 > >【貝葉斯分析⑦】高斯過程

【貝葉斯分析⑦】高斯過程

貝葉斯框架下, 可以用高斯過程來估計一個函式 f : R→R. 對於每個xi, f(xi)可以用一個均值方差暫未知的高斯分佈來建模。因為連續空間的xi可以有無限個,擬合一個函式的高斯過程其實一個無限維的多元高斯。實際中,不管是我們的給定資料{(x, y)},還是測試點{x*}的個數都是有限的。因此無論是高斯過程先驗還是還是高斯過程後驗都是有限維的。因為多元高斯分佈的任意有限自己還是多遠高斯分佈,所以先驗和後驗也都可以都用高斯建模。

實踐中,高斯過程先驗(一個多元高斯)的均值通常統一設為0(因為我們通常對於需要擬合的函式取值並沒有太多先驗知識),而協方差矩陣則用核函式K(x,x')來建模(i.e., cov[yi, yj] = K(xi, xj) = η exp(-ρ * SED(xi, xj)) )。協方差描述的是,當x變化時,y是如何變化的.而核函式(高斯核)的特點是x,x'越近,返回值越大。用高斯核作協方差可以理解為一個小的x擾動會導致較小的y變化,一個大的x變化會導致較大的y變化。

此外,現實中觀測值一定會受到噪聲的影響,因此我們對觀測值建模需要引入一項高斯白噪聲:yi ~ f(xi) +  εi. ε ~ N(0, σ)。這樣協方差函式就改寫成 cov[yi, yj] = K(xi, xj) + σ δij.

注意高斯過程的先驗建模中有引數θ = {η, ρ, σ}, 但為了強調高斯過程是一個非參方法,這些引數稱為“超參”。這些引數將用貝葉斯推斷來估計。

總結一下,我們有各種假設先驗(θ的先驗分佈,f(x) ~ N(0, K(x, x'))),高斯似然 y | θ ~ N(0, K(x, x) + σ I), 然後根據y的觀測值可以推斷出超參的分佈。預測就更簡單了,高斯過程的一個優點就是後驗分佈可以直接求得解析解,要預測的點記為(x*, f(x*)), 後驗分佈可得 f(x*) | x, x*, y ~ N(μ*, Σ*), 其中 μ* = K(X*, X) (K(x, x) + σ I)^(-1) y, Σ* = K(x*, x*) - K(x*, x) (K(x, x) + σ I)^(-1) K(x, x*).詳細推導可見Gaussian Processes for Machine Learning一書的Sec. 2.2, 此外可以記住的一個常見結論如下:

然後是程式碼實踐,用20個帶噪的正弦函式觀測點來擬合函式:

#%matplotlib inline
import pymc3 as pm
import numpy as np
import theano.tensor as tt
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist


if __name__ == "__main__":
    np.random.seed(1)
    squared_distance = lambda x, y: cdist(x.reshape(-1,1), y.reshape(-1,1)) ** 2 #SED function

    N = 20         # number of training points.
    n = 100         # number of test points.
    
    np.random.seed(1)
    f = lambda x: np.sin(x).flatten()

    x = np.random.uniform(0, 10, size=N)
    y = np.random.normal(np.sin(x), np.sqrt(0.01))

    plt.plot(x, y, 'o')
    plt.xlabel('$x$', fontsize=16)
    plt.ylabel('$f(x)$', fontsize=16, rotation=0)
    
    with pm.Model() as GP:
        mu = np.zeros(N)
        eta = pm.HalfCauchy('eta', 0.1)
        rho = pm.HalfCauchy('rho', 1)
        sigma = pm.HalfCauchy('sigma', 1)
        
        D = squared_distance(x, x) #SED(x,x)
        
        K = tt.fill_diagonal(eta * pm.math.exp(-rho * D), eta + sigma) #(K(x, x) + σ I)
        
        obs = pm.MvNormal('obs', mu, cov=K, observed=y)
        
    
        test_points = np.linspace(0, 10, 100)
        D_pred = squared_distance(test_points, test_points) #SED(x*,x*)
        D_off_diag = squared_distance(x, test_points) #SED(x,x*) n * N
        
        K_oo = eta * pm.math.exp(-rho * D_pred) #K(x*,x*)
        K_o = eta * pm.math.exp(-rho * D_off_diag) #K(x,x*)

        inv_K = tt.nlinalg.matrix_inverse(K)
        
        mu_post = pm.Deterministic('mu_post', pm.math.dot(pm.math.dot(K_o.T, inv_K), y))
        SIGMA_post = pm.Deterministic('SIGMA_post', K_oo - pm.math.dot(pm.math.dot(K_o.T, inv_K), K_o))        

        step = pm.Metropolis()                
        start = pm.find_MAP()
        trace = pm.sample(1000, step = step, start=start, nchains = 1)        
        varnames = ['eta', 'rho', 'sigma']
        chain = trace[100:]
        pm.traceplot(chain, varnames)
        
        plt.figure()
        y_pred = [np.random.multivariate_normal(m, S) for m,S in zip(chain['mu_post'], chain['SIGMA_post'])]        
        for yp in y_pred:
            plt.plot(test_points, yp, 'r-', alpha=0.1)
        
        plt.plot(x, y, 'bo')
        plt.xlabel('$x$', fontsize=16)
        plt.ylabel('$f(x)$', fontsize=16, rotation=0)

輸出: