
Learning notes: a PyTorch-based residual network (deep residual network)

This post is part of my PyTorch learning-notes series and covers residual networks.

An earlier post in this series (Learning notes: deep-learning-based classification networks) already introduced ResNet. Proposed by Microsoft Research, ResNet won the 2015 ImageNet competition; by introducing the residual block it made it possible to train networks as deep as 152 layers.

(Link to the paper: https://arxiv.org/pdf/1512.03385.pdf)

As a plain network is made deeper, a degradation problem appears: accuracy first rises, then saturates, and then drops as more layers are added. Notably, the paper shows this is not simple overfitting (training error increases as well), and vanishing/exploding gradients are largely handled by normalized initialization and batch normalization. The residual block proposed to address this has the structure shown in the figure below.
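
In code, the block learns a residual F(x) on top of an identity shortcut, so its output is F(x) + x. A minimal sketch of the idea (F here is a hypothetical stand-in for the block's conv-BN stack):

def residual_forward(x, F):
    #instead of fitting the target mapping H(x) directly, fit the residual
    #F(x) = H(x) - x; the identity shortcut gives gradients a direct path
    return F(x) + x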

The code below is adapted from https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/deep_residual_network/main.py#L76-L113

Here is the code:

import torch
import torch.nn as nn
import torchvision
#torchvision bundles popular datasets, model architectures, and common image transforms
import torchvision.transforms as transforms


#Device configuration
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

#Hyper parameters
num_epochs=80
#num_classes=10#CIFAR-10 has 10 classes (set via the ResNet default below)
batch_size=100
learning_rate=0.001

#Image preprocessing modules
#here we use torchvision's data augmentation transforms
transform=transforms.Compose([
                                transforms.Pad(4),#zero-pad each side by 4: 32x32 -> 40x40
                                transforms.RandomHorizontalFlip(),#random horizontal flip with probability 0.5
                                transforms.RandomCrop(32),#random 32x32 crop of the padded image
                                transforms.ToTensor()
                            ])


#CIFAR-10 dataset
train_dataset=torchvision.datasets.CIFAR10(root='./CIFAR10_data',train=True,transform=transform,download=True)
test_dataset=torchvision.datasets.CIFAR10(root='./CIFAR10_data',train=False,transform=transforms.ToTensor())#no augmentation at test time

#data loader 
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)


##########################################################
#residual block
class Residual_Block(nn.Module):
    def __init__(self,i_channel,o_channel,stride=1,downsample=None):
        super(Residual_Block,self).__init__()
        #first 3x3 conv; a stride>1 here downsamples the feature map
        self.conv1=nn.Conv2d(in_channels=i_channel,out_channels=o_channel,kernel_size=3,stride=stride,padding=1,bias=False)
        self.bn1=nn.BatchNorm2d(o_channel)
        self.relu=nn.ReLU(inplace=True)
        
        #second 3x3 conv always keeps the spatial size
        self.conv2=nn.Conv2d(in_channels=o_channel,out_channels=o_channel,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2=nn.BatchNorm2d(o_channel)
        self.downsample=downsample#optional projection so the shortcut matches the new shape
        
    def forward(self,x):
        residual=x
        
        out=self.conv1(x)
        out=self.bn1(out)
        out=self.relu(out)
        out=self.conv2(out)
        out=self.bn2(out)
        
        if self.downsample:
            residual=self.downsample(x)#project the input so the shapes match
        
        out+=residual#add the identity (or projected) shortcut
        out=self.relu(out)
        
        return out
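
#Quick shape check (my own addition, not part of the referenced tutorial):
#a stride-2 block with a matching downsample halves the spatial size.
#   blk = Residual_Block(16, 32, stride=2,
#                        downsample=nn.Sequential(
#                            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
#                            nn.BatchNorm2d(32)))
#   blk(torch.randn(1, 16, 32, 32)).shape  #torch.Size([1, 32, 16, 16])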

##############################################################
#ResNet
class ResNet(nn.Module):
    def __init__(self,block,layers,num_classes=10):
        super(ResNet,self).__init__()
        self.in_channels=16
        self.conv=nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn=nn.BatchNorm2d(16)
        self.relu=nn.ReLU(inplace=True)
        self.layer1=self.make_layer(block,16,layers[0])
        self.layer2=self.make_layer(block,32,layers[1],2)#stride 2: 32x32 -> 16x16
        self.layer3=self.make_layer(block,64,layers[2],2)#stride 2: 16x16 -> 8x8
        self.avg_pool=nn.AvgPool2d(8)#8x8 feature map -> 1x1
        self.fc=nn.Linear(64,num_classes)
    
    
    def make_layer(self,block,out_channels,blocks,stride=1):#blocks = number of residual blocks in this stage
        downsample=None
        if (stride!=1) or (self.in_channels !=out_channels):
            downsample=nn.Sequential(
                                     nn.Conv2d(self.in_channels,out_channels,kernel_size=3,stride=stride,padding=1,bias=False),
                                     nn.BatchNorm2d(out_channels)
                                     )
        layers=[]
        layers.append(block(self.in_channels,out_channels,stride,downsample))
        self.in_channels=out_channels
        for i in range(1,blocks):
            layers.append(block(out_channels,out_channels))
        return nn.Sequential(*layers)#chain all residual blocks of this stage
            
    
    def forward(self,x):
        out = self.conv(x)      #(N,16,32,32)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)  #(N,16,32,32)
        out = self.layer2(out)  #(N,32,16,16)
        out = self.layer3(out)  #(N,64,8,8)
        out = self.avg_pool(out)#(N,64,1,1)
        out = out.view(out.size(0), -1)#flatten to (N,64)
        out = self.fc(out)      #(N,num_classes)
        
        return out

model=ResNet(Residual_Block,[2,2,2]).to(device)#two residual blocks in each of the three stages


#loss and optimizer
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

#for updating learning rate
def update_lr(optimizer,lr):
    for param_group in optimizer.param_groups:
        param_group['lr']=lr
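
#Note (my addition): PyTorch's built-in scheduler expresses the same decay.
#StepLR multiplies the lr by gamma every step_size epochs, which mirrors the
#manual curr_lr/=3 every 20 epochs used below:
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=1/3)
#   ...then call scheduler.step() once per epoch instead of update_lr()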



########################################################################
#train the model
total_step=len(train_loader)
curr_lr=learning_rate
for epoch in range(num_epochs):
    for i,(images,labels) in enumerate(train_loader):
        images=images.to(device)
        labels=labels.to(device)
        
        #forward pass
        outputs=model(images)
        loss=criterion(outputs,labels)
        
        #Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1)%100==0:#log every 100 steps (500 steps per epoch at batch_size=100)
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}"
                  .format(epoch+1,num_epochs,i+1,total_step,loss.item()))
            
    #decay the learning rate
    if (epoch+1)%20==0:#every 20 epochs, divide the learning rate by 3
        curr_lr/=3
        update_lr(optimizer,curr_lr)

#test the model
model.eval()#put BatchNorm layers into evaluation mode
with torch.no_grad():
    correct=0
    total=0
    for images, labels in test_loader:
        images=images.to(device)
        labels=labels.to(device)
        outputs=model(images)
        _,predicted=torch.max(outputs,1)#index of the largest logit per sample
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()
        
    print('accuracy of the model on the test images: {}%'
          .format(100*correct/total))
    
#save the model checkpoint
#torch.save(model.state_dict(),'resnet.ckpt')
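#to load the checkpoint back later (sketch; assumes the save line above was run):
#   model = ResNet(Residual_Block, [2, 2, 2]).to(device)
#   model.load_state_dict(torch.load('resnet.ckpt'))
#   model.eval()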

Results (I hit quite a few bugs while getting this to run, haha; my coding skills clearly still need some work)

 

On data augmentation

Data augmentation improves both the accuracy and the generalization ability of the model; the Pad / RandomHorizontalFlip / RandomCrop pipeline above effectively enlarges the training set with shifted and mirrored copies of each image.
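
As a quick sanity check (my own sketch, reusing the same transform as above), the augmented samples keep the 3x32x32 tensor shape while the content is randomly shifted and mirrored:

import torchvision
import torchvision.transforms as transforms

augment=transforms.Compose([
    transforms.Pad(4),                  #32x32 -> 40x40
    transforms.RandomHorizontalFlip(),  #left-right flip with probability 0.5
    transforms.RandomCrop(32),          #random 32x32 window -> small translations
    transforms.ToTensor()
])
dataset=torchvision.datasets.CIFAR10(root='./CIFAR10_data',train=True,transform=augment,download=True)
image,_=dataset[0]
print(image.shape)#torch.Size([3, 32, 32]); rerun and the crop/flip will differ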

As an aside: when Firefox hangs, it can be killed with the command below:

ps -ef | grep firefox | awk '{print $2}' | xargs kill -9
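
(The one-liner pkill -9 firefox achieves the same thing.)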