
A CIFAR-10 Classification Network Model Based on PyTorch

       As a rising deep learning framework, PyTorch's adoption is steadily growing. Compared with TensorFlow, PyTorch is easier to pick up: it supports dynamic definition of the computation graph, and it makes converting between the network's tensor data and numpy arrays convenient, which makes networks with unusual structures easier to define. On the other hand, PyTorch's support for things like distributed training is relatively weak, and it lacks a tool like TensorBoard for conveniently visualizing the network. Of course, with TensorFlow you can choose a framework like Keras to greatly simplify building a network.
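
       As a tiny illustration of the tensor/numpy interoperability mentioned above (a sketch I'm adding, not from the original post): torch.from_numpy and Tensor.numpy share the underlying memory, so the conversion is essentially free.

import numpy as np
import torch

a = np.ones((2, 3), dtype=np.float32)
t = torch.from_numpy(a)   # numpy -> tensor, shares memory with a
b = t.numpy()             # tensor -> numpy, also shares memory

t += 1                    # an in-place change shows up in a and b as well
print(a)                  # all entries are now 2.0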

       PyTorch has a good set of official tutorials at https://pytorch.org/tutorials/, covering everything from basic operations to image classification, semantic recognition, reinforcement learning, and this year's wildly popular GANs, all explained very clearly. This post mainly follows the official tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html, with some improvements made to the network structure, as practice in using PyTorch.

       Following the official steps, the first thing to do is import the CIFAR-10 dataset through the torchvision library:

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),   # PIL image -> float tensor in [0, 1]
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])   # per channel: (x - 0.5) / 0.5 -> [-1, 1]

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
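
       To sanity-check the loaders (a quick test I'm adding here, not part of the tutorial), pull one batch and confirm the shapes and value range:

images, labels = next(iter(trainloader))
print(images.shape)                              # torch.Size([4, 3, 32, 32])
print(images.min().item(), images.max().item())  # close to -1.0 and 1.0 after Normalize
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))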

       Operations like viewing the images are covered in the official tutorial, so I'll skip them here.

       The second step is to define the convolutional neural network. The example network used on the official site is as follows:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)   # 16 feature maps of 5x5 after two conv+pool stages
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)              # flatten to (batch, 400) for the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
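
       The 16 * 5 * 5 above comes from shape arithmetic: a 32x32 input becomes 28x28 after the first 5x5 convolution, 14x14 after pooling, 10x10 after the second convolution, and 5x5 after the second pooling. A forward pass on a dummy batch (my own sanity check, not from the tutorial) confirms it:

dummy = torch.randn(1, 3, 32, 32)   # one fake CIFAR-10 image
print(net(dummy).shape)             # torch.Size([1, 10]), one score per class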

       This network is the classic conv + fully-connected layout, and it doesn't actually perform very well: the fully connected layers are inefficient at passing information along, they can wash out the local features extracted by the convolutional layers, and the network uses neither BatchNorm nor Dropout to guard against overfitting. Most popular architectures today are built almost entirely from convolutional layers, and the following structure works much better:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)     # padding=1 keeps the 3x3 convs size-preserving
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.avgpool = nn.AvgPool2d(2, 2)
        self.globalavgpool = nn.AvgPool2d(8, 8)         # 8x8 kernel acts as global average pooling here
        # note: each BatchNorm module below is reused for two conv layers,
        # so those layers share affine parameters and running statistics
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.dropout50 = nn.Dropout(0.5)
        self.dropout10 = nn.Dropout(0.1)
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        x = self.bn1(F.relu(self.conv1(x)))             # 32x32, 64 channels
        x = self.bn1(F.relu(self.conv2(x)))
        x = self.maxpool(x)                             # -> 16x16
        x = self.dropout10(x)
        x = self.bn2(F.relu(self.conv3(x)))             # 16x16, 128 channels
        x = self.bn2(F.relu(self.conv4(x)))
        x = self.avgpool(x)                             # -> 8x8
        x = self.dropout10(x)
        x = self.bn3(F.relu(self.conv5(x)))             # 8x8, 256 channels
        x = self.bn3(F.relu(self.conv6(x)))
        x = self.globalavgpool(x)                       # -> 1x1, 256 channels
        x = self.dropout50(x)
        x = x.view(x.size(0), -1)                       # flatten to (batch, 256)
        x = self.fc(x)
        return x

net = Net()

       PyTorch can also define a sequential network very simply with the nn.Sequential container, much like Keras's Sequential, except that PyTorch requires you to give the input and output sizes of every layer, so it isn't quite as hands-off as Keras. Since PyTorch doesn't ship a GlobalAveragePooling layer the way Keras does, I wrote one by hand above (the AvgPool2d(8, 8)) so I don't forget how. In fact, the result might be better without it, since it amounts to forcibly compressing the feature maps before classification.
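
       As an aside (my own addition, not in the original post): current versions of PyTorch do ship nn.AdaptiveAvgPool2d, which pools to a fixed output size regardless of the input resolution, so the hand-rolled global pooling above could instead be written as:

globalavgpool = nn.AdaptiveAvgPool2d(1)   # any HxW input -> 1x1 per channel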

       Here is a basic example of nn.Sequential:

channel_1 = 32
channel_2 = 16
model = nn.Sequential(
    nn.Conv2d(3, channel_1, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(channel_1, channel_2, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),   # collapse (N, C, H, W) to (N, C*H*W) for the linear layer
    nn.Linear(channel_2 * 32 * 32, 10),
)
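
       A quick forward pass (again my own check) confirms the shapes: both padded convolutions preserve the 32x32 resolution, so the linear layer sees channel_2 * 32 * 32 = 16384 features.

print(model(torch.zeros(1, 3, 32, 32)).shape)   # torch.Size([1, 10])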

       The third step is to define the loss function and the optimizer. The official tutorial uses SGD with momentum here, but my feeling is that Adam optimizes complicated functions better than SGD, so I use Adam instead:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
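
       For reference, the optimizer the official tutorial uses at this step is SGD with momentum:

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)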

       Next, the fourth step is to train the network. You can first use the following statement to automatically decide whether to compute on the GPU or the CPU, though generally speaking, a GPU and a CPU of the same tier can differ in compute speed by a factor of 50 to 70...

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

       Now training can begin. Note that the training data also has to be moved with .to(device):

for epoch in range(10):

    running_loss = 0.
    batch_size = 100

    for i, data in enumerate(
            torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                        shuffle=True, num_workers=2), 0):

        # move the batch to the same device as the model
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # clear old gradients, then forward / backward / update
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # report the average loss every 100 mini-batches instead of every step
        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            print('[%d, %5d] loss: %.4f' %
                  (epoch + 1, (i + 1) * batch_size, running_loss / 100))
            running_loss = 0.

print('Finished Training')

       Of course, PyTorch isn't as convenient as Keras's single .fit call, but it isn't much trouble either. And since the computation doesn't have to live inside a session, it's far more flexible than TensorFlow or the tightly sealed Keras.

       Afterwards, you can save the model, or load a saved one, with the following statements:

torch.save(net, 'cifar10.pkl')
net = torch.load('cifar10.pkl')
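
       Saving the whole module this way pickles the class reference together with the weights. A more portable pattern, and the one the PyTorch documentation recommends (my own note, not from the original post), is to save only the state_dict:

torch.save(net.state_dict(), 'cifar10_weights.pkl')    # weights only

net = Net()                                            # rebuild the architecture first
net.load_state_dict(torch.load('cifar10_weights.pkl'))
net.to(device)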

       Once training is done, you can check the results on the test set. Since this network contains BatchNorm and Dropout layers, remember to switch it into evaluation mode with net.eval() first:

net.eval()   # switch BatchNorm and Dropout to inference mode before evaluating

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)   # index of the highest score = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):   # 4 = the testloader batch size
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

       After training for 10 epochs with batch_size=100, this model reaches roughly 80% accuracy on the test set, which is respectable for CIFAR-10.

       The source code is up on GitHub, feel free to grab it: https://github.com/PolarisShi/cifar10