1. 程式人生 > 用 PyTorch 實現多層感知機(MLP,即全連接神經網路 FC)進行 MNIST 手寫數字識別

用 PyTorch 實現多層感知機(MLP,即全連接神經網路 FC)進行 MNIST 手寫數字識別

1.匯入必備的包

1 import torch 
2 import numpy as np 
3 from torchvision.datasets import mnist
4 from torch import nn
5 from torch.autograd import Variable
6 import matplotlib.pyplot as plt
7 import torch.nn.functional as F
8 from torch.utils.data import DataLoader
9 %matplotlib inline

2. 定義 MNIST 資料的格式變換(將像素值縮放到 [-1, 1] 並展平為一維向量)

def data_transform(x):
    """Turn one MNIST image into a flat float32 tensor with values in [-1, 1].

    The pixels are first divided by 255 (assumes 8-bit pixel values, 0-255),
    giving [0, 1]. They are then centred on 0.5 and divided by 0.5, which
    maps them onto [-1, 1]. The result is flattened to one dimension so it
    can be fed straight into a fully connected layer.

    Args:
        x: anything ``np.array`` accepts, e.g. a PIL image or nested lists.

    Returns:
        A 1-D ``torch.float32`` tensor that shares memory with the
        intermediate NumPy array.
    """
    scaled = np.array(x, dtype='float32') / 255
    centred = (scaled - 0.5) / 0.5
    return torch.from_numpy(centred.reshape(-1))

3. 下載資料集,並定義資料載入器(DataLoader)

# Build the MNIST train/test datasets. download=True asks torchvision to fetch
# the files into ./dataset/mnist when they are not already there, and every
# image is passed through data_transform (flat tensor in [-1, 1]).
trainset = mnist.MNIST('./dataset/mnist', train=True,
                       transform=data_transform, download=True)
testset = mnist.MNIST('./dataset/mnist', train=False,
                      transform=data_transform, download=True)

# Training batches are reshuffled every epoch; the test set keeps a fixed
# order so evaluation is repeatable.
train_data = DataLoader(trainset, batch_size=64, shuffle=True)
test_data = DataLoader(testset, batch_size=128, shuffle=False)

4. 定義全連接神經網路(多層感知機)

 1 class MLP(nn.Module):
 2     def __init__(self):
 3         super(MLP, self).__init__
() 4 self.fc1 = nn.Linear(28*28, 500) 5 self.fc2 = nn.Linear(500, 250) 6 self.fc3 = nn.Linear(250, 125) 7 self.fc4 = nn.Linear(125, 10) 8 9 def forward(self, x): 10 x = F.relu(self.fc1(x)) 11 x = F.relu(self.fc2(x)) 12 x = F.relu(self.fc3(x)) 13 x = self.fc4(x) 14 return x 15 16 mlp = MLP()

5.定義損失函式和優化器

# Optimizer: plain stochastic gradient descent over all of the network's
# parameters, with a learning rate of 1e-3.
optimizer = torch.optim.SGD(params=mlp.parameters(), lr=1e-3)
# Loss: cross-entropy between the network's raw output scores and the
# target labels.
criterion = nn.CrossEntropyLoss()

6.開始訓練和測試

 1 losses = []
 2 acces = []
 3 eval_losses = []
 4 eval_acces = []
 5 
 6 for e in range(20):
 7     train_loss = 0
 8     train_acc = 0
 9     mlp.train()
10     for im, label in train_data:
11         im = Variable(im)
12         label = Variable(label)
13         # 前向傳播
14         out = mlp(im)
15         loss = criterion(out, label)
16         # 反向傳播
17         optimizer.zero_grad()
18         loss.backward()
19         optimizer.step()
20         # 記錄誤差
21         train_loss += loss.item()
22         # 計算分類的準確率
23         _, pred = out.max(1)
24         num_correct = (pred == label).sum().item()
25         acc = num_correct / im.shape[0]
26         train_acc += acc
27         
28     losses.append(train_loss / len(train_data))
29     acces.append(train_acc / len(train_data))
30     # 在測試集上檢驗效果
31     eval_loss = 0
32     eval_acc = 0
33     mlp.eval() # 將模型改為預測模式
34     for im, label in test_data:
35         im = Variable(im)
36         label = Variable(label)
37         out = mlp(im)
38         loss = criterion(out, label)
39         # 記錄誤差
40         eval_loss += loss.item()
41         # 記錄準確率
42         _, pred = out.max(1)
43         num_correct = (pred == label).sum().item()
44         acc = num_correct / im.shape[0]
45         eval_acc += acc
46         
47     eval_losses.append(eval_loss / len(test_data))
48     eval_acces.append(eval_acc / len(test_data))
49     print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
50           .format(e, train_loss / len(train_data), train_acc / len(train_data), 
51                      eval_loss / len(test_data), eval_acc / len(test_data)))

7. 訓練與測試結果

epoch: 0, Train Loss: 2.287240, Train Acc: 0.124150, Eval Loss: 2.265074, Eval Acc: 0.237540
epoch: 1, Train Loss: 2.237043, Train Acc: 0.385861, Eval Loss: 2.197773, Eval Acc: 0.524921
epoch: 2, Train Loss: 2.138911, Train Acc: 0.555487, Eval Loss: 2.050214, Eval Acc: 0.554292
epoch: 3, Train Loss: 1.901877, Train Acc: 0.563833, Eval Loss: 1.688784, Eval Acc: 0.592662
epoch: 4, Train Loss: 1.439467, Train Acc: 0.625483, Eval Loss: 1.178063, Eval Acc: 0.704905
epoch: 5, Train Loss: 1.022494, Train Acc: 0.745586, Eval Loss: 0.869467, Eval Acc: 0.778184
epoch: 6, Train Loss: 0.795575, Train Acc: 0.790528, Eval Loss: 0.702586, Eval Acc: 0.808347
epoch: 7, Train Loss: 0.665018, Train Acc: 0.816031, Eval Loss: 0.601074, Eval Acc: 0.831586
epoch: 8, Train Loss: 0.583082, Train Acc: 0.834588, Eval Loss: 0.535897, Eval Acc: 0.843750
epoch: 9, Train Loss: 0.527930, Train Acc: 0.848231, Eval Loss: 0.490443, Eval Acc: 0.857694
epoch: 10, Train Loss: 0.488764, Train Acc: 0.858925, Eval Loss: 0.456138, Eval Acc: 0.866396
epoch: 11, Train Loss: 0.459293, Train Acc: 0.868220, Eval Loss: 0.430784, Eval Acc: 0.873220
epoch: 12, Train Loss: 0.436398, Train Acc: 0.874117, Eval Loss: 0.413343, Eval Acc: 0.875890
epoch: 13, Train Loss: 0.418043, Train Acc: 0.880031, Eval Loss: 0.396967, Eval Acc: 0.880340
epoch: 14, Train Loss: 0.403195, Train Acc: 0.884029, Eval Loss: 0.385431, Eval Acc: 0.885483
epoch: 15, Train Loss: 0.390613, Train Acc: 0.887327, Eval Loss: 0.372552, Eval Acc: 0.889537
epoch: 16, Train Loss: 0.379947, Train Acc: 0.890275, Eval Loss: 0.363168, Eval Acc: 0.891812
epoch: 17, Train Loss: 0.370701, Train Acc: 0.893557, Eval Loss: 0.355597, Eval Acc: 0.894482
epoch: 18, Train Loss: 0.362498, Train Acc: 0.896572, Eval Loss: 0.348329, Eval Acc: 0.897844
epoch: 19, Train Loss: 0.354748, Train Acc: 0.898121, Eval Loss: 0.340272, Eval Acc: 0.899921

8. 訓練損失和訓練準確率曲線

# Plot the per-epoch training curves. Each curve gets its own figure:
# without plt.figure(), a plain script would draw both lines on the same
# axes, and the second title would overwrite the first.
plt.figure()
plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)
plt.xlabel('epoch')
plt.ylabel('loss')

plt.figure()
plt.plot(np.arange(len(acces)), acces)
plt.title('train acc')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.show()