從頭學pytorch(十九):批量歸一化batch normalization
批量歸一化
論文地址:https://arxiv.org/abs/1502.03167
批量歸一化基本上是現在模型的標配了.
說實在的,到今天我也沒搞明白batch normalize能夠使得模型訓練更穩定的底層原因,要徹底搞清楚,涉及到很多凸優化的理論,需要非常紮實的數學基礎才行.
目前為止,我理解的批量歸一化即把每一層輸入的特徵,統一變換到統一的尺度上來,避免各個特徵的單位不統一的情況.即把每一個特徵的分佈都轉變為均值為0,方差為1的分佈.
然後在變換後的資料的基礎上加一個線性變換.
關於batch normalize的常見問題,參考:https://zhuanlan.zhihu.com/p/55852062
對全連線層做批量歸一化
我們先考慮如何對全連線層做批量歸一化。通常,我們將批量歸一化層置於全連線層中的仿射變換和啟用函式之間。設全連線層的輸入為\(\boldsymbol{u}\),權重引數和偏差引數分別為\(\boldsymbol{W}\)和\(\boldsymbol{b}\),啟用函式為\(\phi\)。設批量歸一化的運算子為\(\text{BN}\)。那麼,使用批量歸一化的全連線層的輸出為
\[\phi(\text{BN}(\boldsymbol{x})),\]
其中批量歸一化輸入\(\boldsymbol{x}\)由仿射變換
\[\boldsymbol{x} = \boldsymbol{W\boldsymbol{u} + \boldsymbol{b}}\]
得到。考慮一個由\(m\)個樣本組成的小批量,仿射變換的輸出為一個新的小批量\(\mathcal{B} = \{\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(m)} \}\)。它們正是批量歸一化層的輸入。對於小批量\(\mathcal{B}\)中任意樣本\(\boldsymbol{x}^{(i)} \in \mathbb{R}^d, 1 \leq i \leq m\),批量歸一化層的輸出同樣是\(d\)維向量
\[\boldsymbol{y}^{(i)} = \text{BN}(\boldsymbol{x}^{(i)}),\]
並由以下幾步求得。首先,對小批量\(\mathcal{B}\)求均值和方差:
\[\boldsymbol{\mu}_\mathcal{B} \leftarrow \frac{1}{m}\sum_{i = 1}^{m} \boldsymbol{x}^{(i)},\]
\[\boldsymbol{\sigma}_\mathcal{B}^2 \leftarrow \frac{1}{m} \sum_{i=1}^{m}(\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B})^2,\]
其中的平方計算是按元素求平方。接下來,使用按元素開方和按元素除法對\(\boldsymbol{x}^{(i)}\)標準化:
\[\hat{\boldsymbol{x}}^{(i)} \leftarrow \frac{\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}}{\sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}},\]
這裡\(\epsilon > 0\)是一個很小的常數,保證分母大於0。在上面標準化的基礎上,批量歸一化層引入了兩個可以學習的模型引數,拉伸(scale)引數 \(\boldsymbol{\gamma}\) 和偏移(shift)引數 \(\boldsymbol{\beta}\)。這兩個引數和\(\boldsymbol{x}^{(i)}\)形狀相同,皆為\(d\)維向量。這就是文章開頭說的對特徵做normalization後,再做一次線性變換
它們與\(\boldsymbol{x}^{(i)}\)分別做按元素乘法(符號\(\odot\))和加法計算:
\[{\boldsymbol{y}}^{(i)} \leftarrow \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}}^{(i)} + \boldsymbol{\beta}.\]
至此,我們得到了\(\boldsymbol{x}^{(i)}\)的批量歸一化的輸出\(\boldsymbol{y}^{(i)}\)。
值得注意的是,可學習的拉伸和偏移引數保留了不對\(\hat{\boldsymbol{x}}^{(i)}\)做批量歸一化的可能:此時只需學出\(\boldsymbol{\gamma} = \sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}\)和\(\boldsymbol{\beta} = \boldsymbol{\mu}_\mathcal{B}\)。我們可以對此這樣理解:如果批量歸一化無益,理論上,學出的模型可以不使用批量歸一化。
對卷積層做批量歸一化
對卷積層來說,批量歸一化發生在卷積計算之後、應用啟用函式之前。如果卷積計算輸出多個通道,我們需要對這些通道的輸出分別做批量歸一化,且每個通道都擁有獨立的拉伸和偏移引數,並均為標量。設小批量中有\(m\)個樣本。在單個通道上,假設卷積計算輸出的高和寬分別為\(p\)和\(q\)。我們需要對該通道中\(m \times p \times q\)個元素同時做批量歸一化。對這些元素做標準化計算時,我們使用相同的均值和方差,即該通道中\(m \times p \times q\)個元素的均值和方差。
用個更具體點的例子總結一下就是:
對於全連線層,假設輸出shape為[batch,256],那歸一化即對256列的每一列求平均.
對於卷積層,假設輸出shape為[batch,96,5,5],即對每個樣本來說,有96個5x5的feature map,歸一化在96個channel上分別做歸一化,均值為batchx5x5個數的均值.
預測時的批量歸一化
這時候,還有一個問題,就是模型訓練好了,傳入輸入,計算前向傳播的結果,也是要做歸一化的處理的.那這時候我用的均值和方差應該是多少呢? 很顯然,不應該是某個batch的樣本的均值和方差,而應該是所有樣本的均值和方差.因為gamma和beta的更新是不斷累積的結果,而不是僅僅參考某一個batch的輸入.(注意,這裡的樣本不是指模型的輸入圖片矩陣,而是指歸一化層的輸入,這個輸入隨著訓練的進行是在不斷變化的,而且不同的歸一化層的輸入是不一樣的).所以,在做batch normalize的時候,我們還要維護一個值,用於估計全部樣本的均值,方差.一種常見的方法是移動平均法.
可以通過下面的測試程式碼看一下moving_mean是如何逼近3的
momentum=0.9
moving_mean = 0.0
for epoch in range(10):
for mean in [1,2,3,4,5]:
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
print(moving_mean)
至於為何不直接對均值之和求平均,我在torch論壇提問了,目前還沒回復.
現在我們來總結一下batch normalize的計算過程,然後實現它.分為訓練/測試兩個部分.
訓練:
- 求輸入x的均值
- 求輸入x的方差
- 將x歸一化
\[\hat{\boldsymbol{x}}^{(i)} \leftarrow \frac{\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}}{\sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}},\] - 對歸一化後的x做線性變換
\[{\boldsymbol{y}}^{(i)} \leftarrow \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}}^{(i)} + \boldsymbol{\beta}.\]
測試:
- 使用移動平均所得的均值和方差,計算歸一化的值
- 對歸一化後的值做線性變換
那麼可以寫出BatchNorm的定義
def batch_norm(is_training,X,eps,gamma,beta,running_mean,running_var,alpha):
assert len(X.shape) in (2,4)
if is_training:
#X [batch,n]
if len(X.shape) == 2:
mean = X.mean(dim=0)
var = ((X-mean) ** 2).mean(dim=0)
else:
#X [batch,c,h,w]
mean = X.mean(dim=0,keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
var = ((X-mean) ** 2).mean(dim=0,keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
X_hat = (X - mean) / torch.sqrt(var + eps)
running_mean = alpha * mean + (1 - alpha) * running_mean
running_var = alpha * var + (1 - alpha) * running_var
else:
X_hat = (X - running_mean) / torch.sqrt(running_var + eps)
#print(gamma.shape,X_hat.shape,beta.shape)
Y = gamma * X_hat + beta #
return Y,running_mean,running_var
class BatchNorm(nn.Module):
def __init__(self,is_conv,in_channels):
super(BatchNorm,self).__init__()
#卷積層/全連線層歸一化後的線性變換引數.
if not is_conv:
# x:[batch,n]
shape = (1,in_channels)
self.gamma = nn.Parameter(torch.ones(shape)) #是可學習的引數.反向傳播時需要根據梯度更新.
self.beta = nn.Parameter(torch.zeros(shape)) #是可學習的引數.反向傳播時需要根據梯度更新.
self.running_mean = torch.zeros(shape) #不需要求梯度.在forward時候更新.
self.running_var = torch.zeros(shape) #不需要求梯度.在forward時候更新.
else:
# x:[btach,c,h,w]
shape = (1,in_channels,1,1)
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.ones(shape))
self.running_mean = torch.zeros(shape)
self.running_var = torch.zeros(shape)
self.eps = 1e-5
self.momentum=0.9
def forward(self,x):
# 如果X不在記憶體上,將moving_mean和moving_var複製到X所在視訊記憶體上
if self.running_mean.device != x.device:
self.running_mean = self.running_mean.to(x.device)
self.running_var = self.running_var.to(x.device)
# self.training繼承自nn.Module,預設true,呼叫.eval()會設定成false
if self.training:
Y,self.running_mean,self.running_var = batch_norm(True,x,self.eps,self.gamma,self.beta,self.running_mean,self.running_var,self.momentum)
else:
Y,self.running_mean,self.running_var = batch_norm(False,x,self.eps,self.gamma,self.beta,self.running_mean,self.running_var,self.momentum)
return Y
BatchNorm繼承自nn.Module,含有可學習引數gamma
,beta
,反向傳播時會更新他們. 引數running_mean
,running_var
在前向傳播時計算.
batch_norm
需要區分是卷積後的歸一化還是全連線後的歸一化.卷積的歸一化是對每個channel單獨求均值.
資料載入
batch_size,num_workers=16,2
train_iter,test_iter = learntorch_utils.load_data(batch_size,num_workers,None)
模型定義
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
BatchNorm(is_conv=True,in_channels=6),
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # kernel_size, stride
nn.Conv2d(6, 16, 5),
BatchNorm(is_conv=True,in_channels=16),
nn.Sigmoid(),
nn.MaxPool2d(2, 2)
)
self.fc = nn.Sequential(
nn.Linear(16*4*4, 120),
BatchNorm(is_conv=False,in_channels=120),
nn.Sigmoid(),
nn.Linear(120, 84),
BatchNorm(is_conv=False,in_channels = 84),
nn.Sigmoid(),
nn.Linear(84, 10)
)
def forward(self, img):
feature = self.conv(img)
output = self.fc(feature.view(img.shape[0], -1))
return output
net = LeNet().cuda()
損失函式定義
l = nn.CrossEntropyLoss()
優化器定義
opt = torch.optim.Adam(net.parameters(),lr=0.01)
評估函式定義
def test():
acc_sum = 0
batch = 0
for X,y in test_iter:
X,y = X.cuda(),y.cuda()
y_hat = net(X)
acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()
batch += 1
print('acc:%f' % (acc_sum/(batch*batch_size)))
訓練
num_epochs=5
def train():
for epoch in range(num_epochs):
train_l_sum,batch=0,0
start = time.time()
for X,y in train_iter:
X,y = X.cuda(),y.cuda() #把tensor放到視訊記憶體
y_hat = net(X) #前向傳播
loss = l(y_hat,y) #計算loss,nn.CrossEntropyLoss中會有softmax的操作
opt.zero_grad()#梯度清空
loss.backward()#反向傳播,求出梯度
opt.step()#根據梯度,更新引數
train_l_sum += loss.item()
batch += 1
end = time.time()
time_per_epoch = end - start
print('epoch %d,train_loss %f,time %f' % (epoch + 1,train_l_sum/(batch*batch_size),time_per_epoch))
test()
train()
加了BN層以後,視訊記憶體直接不夠用了.但是用torch自己的nn.BatchNorm2d
和nn.BatchNorm1d
就沒有問題.應該還是自己的對BatchNorm的實現哪裡不夠好.
使用torch自己的BatchNorm的實現定義的模型如下:
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
nn.BatchNorm2d(6),
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # kernel_size, stride
nn.Conv2d(6, 16, 5),
nn.BatchNorm2d(16),
nn.Sigmoid(),
nn.MaxPool2d(2, 2)
)
self.fc = nn.Sequential(
nn.Linear(16*4*4, 120),
nn.BatchNorm1d(120),
nn.Sigmoid(),
nn.Linear(120, 84),
nn.BatchNorm1d(84),
nn.Sigmoid(),
nn.Linear(84, 10)
)
def forward(self, img):
feature = self.conv(img)
output = self.fc(feature.view(img.shape[0], -1))
return output
net = LeNet().cuda()
訓練輸出如下:
epoch 1,batch_size 4,train_loss 0.194394,time 50.538379
acc:0.789400
epoch 2,batch_size 4,train_loss 0.146268,time 52.352518
acc:0.789500
epoch 3,batch_size 4,train_loss 0.132021,time 52.240710
acc:0.820600
epoch 4,batch_size 4,train_loss 0.126241,time 53.277958
acc:0.824400
epoch 5,batch_size 4,train_loss 0.120607,time 52.067259
acc:0.831800