1. 程式人生 > >從頭學pytorch(十九):批量歸一化batch normalization

從頭學pytorch(十九):批量歸一化batch normalization

批量歸一化

論文地址:https://arxiv.org/abs/1502.03167
批量歸一化基本上是現在模型的標配了.
說實在的,到今天我也沒搞明白batch normalize能夠使得模型訓練更穩定的底層原因,要徹底搞清楚,涉及到很多凸優化的理論,需要非常紮實的數學基礎才行.
目前為止,我理解的批量歸一化即把每一層輸入的特徵,統一變換到統一的尺度上來,避免各個特徵的單位不統一的情況.即把每一個特徵的分佈都轉變為均值為0,方差為1的分佈.
然後在變換後的資料的基礎上加一個線性變換.
關於batch normalize的常見問題,參考:https://zhuanlan.zhihu.com/p/55852062

對全連線層做批量歸一化

我們先考慮如何對全連線層做批量歸一化。通常,我們將批量歸一化層置於全連線層中的仿射變換和啟用函式之間。設全連線層的輸入為\(\boldsymbol{u}\),權重引數和偏差引數分別為\(\boldsymbol{W}\)和\(\boldsymbol{b}\),啟用函式為\(\phi\)。設批量歸一化的運算子為\(\text{BN}\)。那麼,使用批量歸一化的全連線層的輸出為

\[\phi(\text{BN}(\boldsymbol{x})),\]

其中批量歸一化輸入\(\boldsymbol{x}\)由仿射變換

\[\boldsymbol{x} = \boldsymbol{W\boldsymbol{u} + \boldsymbol{b}}\]

得到。考慮一個由\(m\)個樣本組成的小批量,仿射變換的輸出為一個新的小批量\(\mathcal{B} = \{\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(m)} \}\)。它們正是批量歸一化層的輸入。對於小批量\(\mathcal{B}\)中任意樣本\(\boldsymbol{x}^{(i)} \in \mathbb{R}^d, 1 \leq i \leq m\),批量歸一化層的輸出同樣是\(d\)維向量

\[\boldsymbol{y}^{(i)} = \text{BN}(\boldsymbol{x}^{(i)}),\]

並由以下幾步求得。首先,對小批量\(\mathcal{B}\)求均值和方差:

\[\boldsymbol{\mu}_\mathcal{B} \leftarrow \frac{1}{m}\sum_{i = 1}^{m} \boldsymbol{x}^{(i)},\]
\[\boldsymbol{\sigma}_\mathcal{B}^2 \leftarrow \frac{1}{m} \sum_{i=1}^{m}(\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B})^2,\]

其中的平方計算是按元素求平方。接下來,使用按元素開方和按元素除法對\(\boldsymbol{x}^{(i)}\)標準化:

\[\hat{\boldsymbol{x}}^{(i)} \leftarrow \frac{\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}}{\sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}},\]

這裡\(\epsilon > 0\)是一個很小的常數,保證分母大於0。在上面標準化的基礎上,批量歸一化層引入了兩個可以學習的模型引數,拉伸(scale)引數 \(\boldsymbol{\gamma}\) 和偏移(shift)引數 \(\boldsymbol{\beta}\)。這兩個引數和\(\boldsymbol{x}^{(i)}\)形狀相同,皆為\(d\)維向量。這就是文章開頭說的對特徵做normalization後,再做一次線性變換
它們與\(\boldsymbol{x}^{(i)}\)分別做按元素乘法(符號\(\odot\))和加法計算:

\[{\boldsymbol{y}}^{(i)} \leftarrow \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}}^{(i)} + \boldsymbol{\beta}.\]

至此,我們得到了\(\boldsymbol{x}^{(i)}\)的批量歸一化的輸出\(\boldsymbol{y}^{(i)}\)。
值得注意的是,可學習的拉伸和偏移引數保留了不對\(\hat{\boldsymbol{x}}^{(i)}\)做批量歸一化的可能:此時只需學出\(\boldsymbol{\gamma} = \sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}\)和\(\boldsymbol{\beta} = \boldsymbol{\mu}_\mathcal{B}\)。我們可以對此這樣理解:如果批量歸一化無益,理論上,學出的模型可以不使用批量歸一化。

 對卷積層做批量歸一化

對卷積層來說,批量歸一化發生在卷積計算之後、應用啟用函式之前。如果卷積計算輸出多個通道,我們需要對這些通道的輸出分別做批量歸一化,且每個通道都擁有獨立的拉伸和偏移引數,並均為標量。設小批量中有\(m\)個樣本。在單個通道上,假設卷積計算輸出的高和寬分別為\(p\)和\(q\)。我們需要對該通道中\(m \times p \times q\)個元素同時做批量歸一化。對這些元素做標準化計算時,我們使用相同的均值和方差,即該通道中\(m \times p \times q\)個元素的均值和方差。

用個更具體點的例子總結一下就是:
對於全連線層,假設輸出shape為[batch,256],那歸一化即對256列的每一列求平均.
對於卷積層,假設輸出shape為[batch,96,5,5],即對每個樣本來說,有96個5x5的feature map,歸一化在96個channel上分別做歸一化,均值為batchx5x5個數的均值.

預測時的批量歸一化

這時候,還有一個問題,就是模型訓練好了,傳入輸入,計算前向傳播的結果,也是要做歸一化的處理的.那這時候我用的均值和方差應該是多少呢? 很顯然,不應該是某個batch的樣本的均值和方差,而應該是所有樣本的均值和方差.因為gamma和beta的更新是不斷累積的結果,而不是僅僅參考某一個batch的輸入.(注意,這裡的樣本不是指模型的輸入圖片矩陣,而是指歸一化層的輸入,這個輸入隨著訓練的進行是在不斷變化的,而且不同的歸一化層的輸入是不一樣的).所以,在做batch normalize的時候,我們還要維護一個值,用於估計全部樣本的均值,方差.一種常見的方法是移動平均法.

可以通過下面的測試程式碼看一下moving_mean是如何逼近3的

momentum=0.9
moving_mean = 0.0
for epoch in range(10):
    for mean in [1,2,3,4,5]:
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        print(moving_mean)

至於為何不直接對均值之和求平均,我在torch論壇提問了,目前還沒回復.

現在我們來總結一下batch normalize的計算過程,然後實現它.分為訓練/測試兩個部分.
訓練:

  • 求輸入x的均值
  • 求輸入x的方差
  • 將x歸一化
    \[\hat{\boldsymbol{x}}^{(i)} \leftarrow \frac{\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}}{\sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}},\]
  • 對歸一化後的x做線性變換
    \[{\boldsymbol{y}}^{(i)} \leftarrow \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}}^{(i)} + \boldsymbol{\beta}.\]

測試:

  • 使用移動平均所得的均值和方差,計算歸一化的值
  • 對歸一化後的值做線性變換

那麼可以寫出BatchNorm的定義

def batch_norm(is_training,X,eps,gamma,beta,running_mean,running_var,alpha):
    assert len(X.shape) in (2,4)
    if is_training:
        #X [batch,n]
        if len(X.shape) == 2:
            mean = X.mean(dim=0)
            var = ((X-mean) ** 2).mean(dim=0)
        else:
        #X [batch,c,h,w]
            mean = X.mean(dim=0,keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            var = ((X-mean) ** 2).mean(dim=0,keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
    
        X_hat = (X - mean) / torch.sqrt(var + eps)
        running_mean = alpha * mean + (1 - alpha) * running_mean
        running_var = alpha * var + (1 - alpha) * running_var
    else:
        X_hat = (X - running_mean) / torch.sqrt(running_var + eps)
    
    #print(gamma.shape,X_hat.shape,beta.shape)
    Y = gamma * X_hat + beta  #

    return Y,running_mean,running_var

class BatchNorm(nn.Module):
    def __init__(self,is_conv,in_channels):
        super(BatchNorm,self).__init__()
        #卷積層/全連線層歸一化後的線性變換引數.
        if not is_conv:
            # x:[batch,n]
            shape = (1,in_channels)
            self.gamma = nn.Parameter(torch.ones(shape)) #是可學習的引數.反向傳播時需要根據梯度更新.
            self.beta = nn.Parameter(torch.zeros(shape)) #是可學習的引數.反向傳播時需要根據梯度更新.
            self.running_mean = torch.zeros(shape) #不需要求梯度.在forward時候更新.
            self.running_var = torch.zeros(shape) #不需要求梯度.在forward時候更新.
        else:
            # x:[btach,c,h,w]
            shape = (1,in_channels,1,1)
            self.gamma = nn.Parameter(torch.ones(shape))
            self.beta = nn.Parameter(torch.ones(shape))
            self.running_mean = torch.zeros(shape)
            self.running_var = torch.zeros(shape)

        self.eps = 1e-5
        self.momentum=0.9

    def forward(self,x):
        # 如果X不在記憶體上,將moving_mean和moving_var複製到X所在視訊記憶體上
        if self.running_mean.device != x.device:
            self.running_mean = self.running_mean.to(x.device)
            self.running_var = self.running_var.to(x.device)

        # self.training繼承自nn.Module,預設true,呼叫.eval()會設定成false
        if self.training:
            Y,self.running_mean,self.running_var = batch_norm(True,x,self.eps,self.gamma,self.beta,self.running_mean,self.running_var,self.momentum)
        else:
            Y,self.running_mean,self.running_var = batch_norm(False,x,self.eps,self.gamma,self.beta,self.running_mean,self.running_var,self.momentum)
        
        return Y

BatchNorm繼承自nn.Module,含有可學習引數gamma,beta,反向傳播時會更新他們. 引數running_mean,running_var在前向傳播時計算.    
batch_norm需要區分是卷積後的歸一化還是全連線後的歸一化.卷積的歸一化是對每個channel單獨求均值.

資料載入

batch_size,num_workers=16,2
train_iter,test_iter = learntorch_utils.load_data(batch_size,num_workers,None)

模型定義

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            BatchNorm(is_conv=True,in_channels=6),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2), # kernel_size, stride
            nn.Conv2d(6, 16, 5),
            BatchNorm(is_conv=True,in_channels=16),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(16*4*4, 120),
            BatchNorm(is_conv=False,in_channels=120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            BatchNorm(is_conv=False,in_channels = 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        output = self.fc(feature.view(img.shape[0], -1))
        return output

net = LeNet().cuda()

損失函式定義

l = nn.CrossEntropyLoss()

優化器定義

opt = torch.optim.Adam(net.parameters(),lr=0.01)

評估函式定義

def test():
    acc_sum = 0
    batch = 0
    for X,y in test_iter:
        X,y = X.cuda(),y.cuda()
        y_hat = net(X)
        acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()
        batch += 1
    print('acc:%f' % (acc_sum/(batch*batch_size)))

訓練

num_epochs=5
def train():
    for epoch in range(num_epochs):
        train_l_sum,batch=0,0
        start = time.time()
        for X,y in train_iter:
            X,y = X.cuda(),y.cuda() #把tensor放到視訊記憶體
            y_hat = net(X)  #前向傳播
            loss = l(y_hat,y) #計算loss,nn.CrossEntropyLoss中會有softmax的操作
            opt.zero_grad()#梯度清空
            loss.backward()#反向傳播,求出梯度
            opt.step()#根據梯度,更新引數

            train_l_sum += loss.item()
            batch += 1
        end = time.time()
        time_per_epoch =  end - start
        print('epoch %d,train_loss %f,time %f' % (epoch + 1,train_l_sum/(batch*batch_size),time_per_epoch))
        test()

train()

加了BN層以後,視訊記憶體直接不夠用了.但是用torch自己的nn.BatchNorm2dnn.BatchNorm1d就沒有問題.應該還是自己的對BatchNorm的實現哪裡不夠好.

使用torch自己的BatchNorm的實現定義的模型如下:

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            nn.BatchNorm2d(6),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2), # kernel_size, stride
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(16*4*4, 120),
            nn.BatchNorm1d(120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        output = self.fc(feature.view(img.shape[0], -1))
        return output

net = LeNet().cuda()

訓練輸出如下:

epoch 1,batch_size 4,train_loss 0.194394,time 50.538379
acc:0.789400
epoch 2,batch_size 4,train_loss 0.146268,time 52.352518
acc:0.789500
epoch 3,batch_size 4,train_loss 0.132021,time 52.240710
acc:0.820600
epoch 4,batch_size 4,train_loss 0.126241,time 53.277958
acc:0.824400
epoch 5,batch_size 4,train_loss 0.120607,time 52.067259
acc:0.831800