用PyTorch實現一個卷積神經網路進行影象分類
阿新 • • 發佈:2019-02-15
1. 回顧
在進入這一篇部落格的內容之前,我們先確保已經成功安裝好PyTorch,可以參考我之前的一篇部落格“Ubuntu12.04下PyTorch詳細安裝記錄”:
http://blog.csdn.net/wblgers1234/article/details/72902016
- 1
接下來,我們用設計一個簡單的卷積神經網路的方式來熟悉PyTorch的用法。
2. 設計卷積神經網路
在設計複雜的神經網路之前,我們依然考慮按照斯坦福大學的“UFLDL Tutorial”的CNN部分來構建一個簡單的卷積神經網路,即按照以下的設計:
輸入層->二維特徵卷積->sigmoid激勵->均值池化->全連線網路->softmax輸出
- 1
按照下面的程式碼對應來看神經網路的結構。註釋得很清晰,有不清楚的可以留言,這裡就不再贅述。
class CNN_net(nn.Module):
def __init__(self):
# 先執行nn.Module的初始化函式
super(CNN_net, self).__init__()
# 卷積層的定義,輸入為1channel的灰度圖,輸出為4特徵,每個卷積kernal為9*9
self.conv = nn.Conv2d(1, 4, 9)
# 均值池化
self.pool = nn.AvgPool2d(2 , 2)
# 全連線後接softmax
self.fc = nn.Linear(10*10*4, 10)
self.softmax = nn.Softmax()
def forward(self, x):
# 卷積層,分別是二維卷積->sigmoid激勵->池化
out = self.conv(x)
out = F.sigmoid(out)
out = self.pool(out)
print(out.size())
# 將特徵的維度進行變化(batchSize*filterDim*featureDim*featureDim->batchSize*flat_features)
out = out.view(-1, self.num_flat_features(out))
# 全連線層和softmax處理
out = self.fc(out)
out = self.softmax(out)
return out
def num_flat_features(self, x):
# 四維特徵,第一維是batchSize
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
3. 資料準備
還記得torchvision嗎?我們在做和影象有關的實驗時會更多地與它打交道。這次我們選擇最簡單也是最廣為人知的MNIST資料庫來訓練和測試CNN。同時在torchvision中有一個torchvision.datasets,它為很多常用的影象資料庫提供介面,其中就包括MNIST。
from torchvision.datasets import MNIST
- 1
需要先下載MNIST,並且轉換為PyTorch可以識別的資料格式:
# MNIST影象資料的轉換函式
trans_img = transforms.Compose([
transforms.ToTensor()
])
# 下載MNIST的訓練集和測試集
trainset = MNIST('./MNIST', train=True, transform=trans_img, download=True)
testset = MNIST('./MNIST', train=False, transform=trans_img, download=True)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
我們檢視transforms.ToTensor()的解釋,將原本的二維影象格式轉換為PyTorch的基本單位torch.FloatTensor。
Converts a PIL.Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
- 1
4. 訓練和測試
4.1 訓練資料集
從程式碼中可以清晰的看見“前向傳播”,“反向傳播”,optimizer的求解。
# 訓練過程
for i in range(epoches):
running_loss = 0.
running_acc = 0.
for (img, label) in trainloader:
# 轉換為Variable型別
img = Variable(img)
label = Variable(label)
optimizer.zero_grad()
# feedforward
output = net(img)
loss = criterian(output, label)
# backward
loss.backward()
optimizer.step()
# 記錄當前的lost以及batchSize資料對應的分類準確數量
running_loss += loss.data[0]
_, predict = torch.max(output, 1)
correct_num = (predict == label).sum()
running_acc += correct_num.data[0]
# 計算並列印訓練的分類準確率
running_loss /= len(trainset)
running_acc /= len(trainset)
print("[%d/%d] Loss: %.5f, Acc: %.2f" %(i+1, epoches, running_loss, 100*running_acc))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
在訓練完成之後,有一個處理很重要,需要將當前的網路設定為“測試模式”,然後才可以進行測試集的驗證。
# 將當前模型設定到測試模式
net.eval()
- 1
- 2
4.2 測試資料集
在測試過程中,只有“前向傳播”過程對輸入的影象進行分類預測。
# 測試過程
testloss = 0.
testacc = 0.
for (img, label) in testloader:
# 轉換為Variable型別
img = Variable(img)
label = Variable(label)
# feedforward
output = net(img)
loss = criterian(output, label)
# 記錄當前的lost以及累加分類正確的樣本數
testloss += loss.data[0]
_, predict = torch.max(output, 1)
num_correct = (predict == label).sum()
testacc += num_correct.data[0]
# 計算並列印測試集的分類準確率
testloss /= len(testset)
testacc /= len(testset)
print("Test: Loss: %.5f, Acc: %.2f %%" %(testloss, 100*testacc))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
4.3 程式碼執行結果
從下面的結果,可以看到迭代10次的訓練分類準確率和測試分類準確率:
CNN_net (
(conv): Conv2d(1, 4, kernel_size=(9, 9), stride=(1, 1))
(pool): AvgPool2d (
)
(fc): Linear (400 -> 10)
(softmax): Softmax ()
)
[1/10] Loss: 1.78497, Acc: 68.79
[2/10] Loss: 1.54269, Acc: 93.10
[3/10] Loss: 1.52096, Acc: 94.93
[4/10] Loss: 1.51040, Acc: 95.82
[5/10] Loss: 1.50393, Acc: 96.45
[6/10] Loss: 1.49967, Acc: 96.77
[7/10] Loss: 1.49655, Acc: 97.02
[8/10] Loss: 1.49401, Acc: 97.24
[9/10] Loss: 1.49192, Acc: 97.45
[10/10] Loss: 1.49050, Acc: 97.56
Test: Loss: 1.48912, Acc: 97.62 %
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
該工程完整的程式碼我已經放到github上,有興趣的可以去下載試試:
https://github.com/wblgers/stanford_dl_cnn/tree/master/PyTorch
- 1