
A PyTorch Implementation of the U-Net Deep Learning Segmentation Network

Original: http://blog.csdn.net/u014722627/article/details/60883185

PyTorch is a very handy tool. As a Python deep learning package its API is convenient, it has automatic differentiation built in, it is well suited to quickly prototyping an idea, and the resulting code is highly readable; see, for example, the recent WGAN[1].
Now, back to U-Net.
Paper: arXiv:1505.04597 [cs.CV]
Homepage: U-Net: Convolutional Networks for Biomedical Image Segmentation
The paper presents a network for biomedical image segmentation. It is a 2015 model and, reportedly, won its challenge in that field. The model is shaped like a giant "U", hence the name U-Net; the recently popular automatic colorization of anime line art[2] also used this model. It may not measure up to today's generative models, but it is still a good exercise in PyTorch, for instance for boundary extraction. Note that the network's output size differs from its input size, so applying it requires some extra handling.
The model looks like this:
(U-Net architecture diagram)
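To make the size bookkeeping concrete, here is a small helper (my own sketch, not from the original post) that reproduces the arithmetic of valid 3x3 convolutions, 2x2 max-pooling, and 2x upsampling:

def unet_output_size(size, depth=4):
    # two valid 3x3 convs remove 4 pixels per level; assumes every pooled size is even,
    # otherwise the skip-connection crops in the decoder would misalign
    for _ in range(depth):       # contracting path: convs, then 2x2 max-pool
        size = (size - 4) // 2
    size -= 4                    # bottleneck convs
    for _ in range(depth):       # expanding path: 2x upsample, then convs
        size = size * 2 - 4
    return size

print(unet_output_size(572))  # 388, matching the figure in the paper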

Following PyTorch's tutorial code, the implementation is as follows:

#unet.py:
from __future__ import division
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    def __init__(self, colordim=1):
        super(UNet, self).__init__()
        # contracting path: two 3x3 valid convolutions per level (each conv shrinks the map by 2)
        self.conv1_1 = nn.Conv2d(colordim, 64, 3)  # input of (n,n,colordim), output of (n-2,n-2,64)
        self.conv1_2 = nn.Conv2d(64, 64, 3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2_1 = nn.Conv2d(64, 128, 3)
        self.conv2_2 = nn.Conv2d(128, 128, 3)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3_1 = nn.Conv2d(128, 256, 3)
        self.conv3_2 = nn.Conv2d(256, 256, 3)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4_1 = nn.Conv2d(256, 512, 3)
        self.conv4_2 = nn.Conv2d(512, 512, 3)
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5_1 = nn.Conv2d(512, 1024, 3)
        self.conv5_2 = nn.Conv2d(1024, 1024, 3)
        # expanding path: 1x1 convs halve the channels after each bilinear upsampling
        self.upconv5 = nn.Conv2d(1024, 512, 1)
        self.bn5 = nn.BatchNorm2d(512)
        self.bn5_out = nn.BatchNorm2d(1024)
        self.conv6_1 = nn.Conv2d(1024, 512, 3)
        self.conv6_2 = nn.Conv2d(512, 512, 3)
        self.upconv6 = nn.Conv2d(512, 256, 1)
        self.bn6 = nn.BatchNorm2d(256)
        self.bn6_out = nn.BatchNorm2d(512)
        self.conv7_1 = nn.Conv2d(512, 256, 3)
        self.conv7_2 = nn.Conv2d(256, 256, 3)
        self.upconv7 = nn.Conv2d(256, 128, 1)
        self.bn7 = nn.BatchNorm2d(128)
        self.bn7_out = nn.BatchNorm2d(256)
        self.conv8_1 = nn.Conv2d(256, 128, 3)
        self.conv8_2 = nn.Conv2d(128, 128, 3)
        self.upconv8 = nn.Conv2d(128, 64, 1)
        self.bn8 = nn.BatchNorm2d(64)
        self.bn8_out = nn.BatchNorm2d(128)
        self.conv9_1 = nn.Conv2d(128, 64, 3)
        self.conv9_2 = nn.Conv2d(64, 64, 3)
        self.conv9_3 = nn.Conv2d(64, colordim, 1)
        self.bn9 = nn.BatchNorm2d(colordim)
        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
        self._initialize_weights()

    def forward(self, x1):
        # contracting path
        x1 = F.relu(self.bn1(self.conv1_2(F.relu(self.conv1_1(x1)))))
        x2 = F.relu(self.bn2(self.conv2_2(F.relu(self.conv2_1(self.maxpool(x1))))))
        x3 = F.relu(self.bn3(self.conv3_2(F.relu(self.conv3_1(self.maxpool(x2))))))
        x4 = F.relu(self.bn4(self.conv4_2(F.relu(self.conv4_1(self.maxpool(x3))))))
        xup = F.relu(self.conv5_2(F.relu(self.conv5_1(self.maxpool(x4)))))  # x5
        # expanding path: upsample, center-crop the skip feature map to match, then concatenate
        xup = self.bn5(self.upconv5(self.upsample(xup)))  # x6in
        cropidx = (x4.size(2) - xup.size(2)) // 2
        x4 = x4[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
        xup = self.bn5_out(torch.cat((x4, xup), 1))  # x6 cat x4
        xup = F.relu(self.conv6_2(F.relu(self.conv6_1(xup))))  # x6out
        xup = self.bn6(self.upconv6(self.upsample(xup)))  # x7in
        cropidx = (x3.size(2) - xup.size(2)) // 2
        x3 = x3[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
        xup = self.bn6_out(torch.cat((x3, xup), 1))  # x7 cat x3
        xup = F.relu(self.conv7_2(F.relu(self.conv7_1(xup))))  # x7out
        xup = self.bn7(self.upconv7(self.upsample(xup)))  # x8in
        cropidx = (x2.size(2) - xup.size(2)) // 2
        x2 = x2[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
        xup = self.bn7_out(torch.cat((x2, xup), 1))  # x8 cat x2
        xup = F.relu(self.conv8_2(F.relu(self.conv8_1(xup))))  # x8out
        xup = self.bn8(self.upconv8(self.upsample(xup)))  # x9in
        cropidx = (x1.size(2) - xup.size(2)) // 2
        x1 = x1[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
        xup = self.bn8_out(torch.cat((x1, xup), 1))  # x9 cat x1
        xup = F.relu(self.conv9_3(F.relu(self.conv9_2(F.relu(self.conv9_1(xup))))))  # x9out
        return F.softsign(self.bn9(xup))

    def _initialize_weights(self):
        # He initialization for the convolutions, identity scaling for BatchNorm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


if __name__ == '__main__':
    # guard the instantiation so that importing UNet does not require a GPU
    unet = UNet().cuda()
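As a quick sanity check (my addition, not part of the original post), the paper's 572x572 example input should come out as 388x388, and this model reproduces that:

# verify the input/output size relationship (same 0.x-era PyTorch API as above)
import torch
from torch.autograd import Variable
from unet import UNet

net = UNet(colordim=1)
y = net(Variable(torch.randn(1, 1, 572, 572)))
print(y.size())  # torch.Size([1, 1, 388, 388])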

Training set: since I could not find the original dataset, I am using BSDS500 for now, preprocessed as in my previous post.
Because the training set is small, random center cropping and random horizontal flipping can be used as data augmentation. Note that torchvision.transforms cannot apply the same random crop to several input images at once, so this augmentation step was moved into train.py; a sketch of the idea follows below.
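A minimal sketch (my own, not the post's actual code) of such a paired augmentation: draw the random parameters once and apply them to both the image and its ground truth:

import random
from PIL import Image

def paired_augment(img, target, crop_size):
    # one shared crop offset for both images
    w, h = img.size
    x = random.randint(0, w - crop_size)
    y = random.randint(0, h - crop_size)
    img = img.crop((x, y, x + crop_size, y + crop_size))
    target = target.crop((x, y, x + crop_size, y + crop_size))
    if random.random() < 0.5:  # one shared coin toss for the horizontal flip
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        target = target.transpose(Image.FLIP_LEFT_RIGHT)
    return img, target

The train.py below implements the shared random crop on tensors instead, after a larger center crop in the loader.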

#BSDDataLoader.py
#This mainly demonstrates PyTorch's dataset loading machinery: preprocessing becomes almost foolproof!
from os.path import exists, join
from torchvision.transforms import Compose, CenterCrop, ToTensor
import torch.utils.data as data
from os import listdir
from PIL import Image


def bsd500(dest="/dir/to/dataset"):  # change this to your own dataset path!
    if not exists(dest):
        print("dataset does not exist")
    return dest


def input_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
        ToTensor()
    ])


def get_training_set(size, target_mode='seg', colordim=1):
    root_dir = bsd500()
    train_dir = join(root_dir, "train")
    return DatasetFromFolder(train_dir,target_mode,colordim,
                             input_transform=input_transform(size),
                             target_transform=input_transform(size))


def get_test_set(size, target_mode='seg', colordim=1):
    root_dir = bsd500()
    test_dir = join(root_dir, "test")
    return DatasetFromFolder(test_dir,target_mode,colordim,
                             input_transform=input_transform(size),
                             target_transform=input_transform(size))




def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])


def load_img(filepath,colordim):
    if colordim==1:
        img = Image.open(filepath).convert('L')
    else:
        img = Image.open(filepath).convert('RGB')
    #y, _, _ = img.split()
    return img


class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, target_mode, colordim, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [x for x in listdir( join(image_dir,'data') ) if is_image_file(x)]
        self.input_transform = input_transform
        self.target_transform = target_transform
        self.image_dir = image_dir
        self.target_mode = target_mode
        self.colordim = colordim

    def __getitem__(self, index):


        input = load_img(join(self.image_dir,'data',self.image_filenames[index]),self.colordim)
        if self.target_mode=='seg':
            target = load_img(join(self.image_dir,'seg',self.image_filenames[index]),1)
        else:
            target = load_img(join(self.image_dir,'bon',self.image_filenames[index]),1)


        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)
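For reference (my addition), the loader expects a layout like <root>/train/data, <root>/train/seg and <root>/train/bon with matching filenames, and is used along these lines (the crop size 448 is a placeholder):

# hypothetical usage; adjust the path inside bsd500() first
train_set = get_training_set(448, target_mode='seg', colordim=1)
img, target = train_set[0]
print(img.size(), target.size())  # both center-cropped to 448x448 and converted to tensors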
#train.py
'''
Since the paper's network has different input and output sizes, it is not clear
how the original authors computed their loss. For simplicity, I center-crop the
ground truth to the output size and use an MSE loss. Training still converges;
with the augmented data the test loss comes out a bit higher, presumably because
the differing scales seen during training hurt generalization.
'''
from __future__ import print_function
from math import log10
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from unet import UNet
from BSDDataLoader import get_training_set,get_test_set
import torchvision


# Training settings
class option:
    def __init__(self):
        self.cuda = True #use cuda?
        self.batchSize = 4 #training batch size
        self.testBatchSize = 4 #testing batch size
        self.nEpochs = 140 #number of epochs to train for
        self.lr = 0.001 #learning rate
        self.threads = 4 #number of threads for data loader to use
        self.seed = 123 #random seed to use. Default=123
        self.size = 428
        self.remsize = 20
        self.colordim = 1
        self.target_mode = 'bon'
        self.pretrain_net = "/home/wcd/PytorchProject/Unet/unetdata/checkpoint/model_epoch_140.pth"

def map01(tensor, eps=1e-5):
    #input/output: tensor; rescale each sample in the batch to [0, 1]
    mx = np.max(tensor.numpy(), axis=(1, 2, 3), keepdims=True)
    mn = np.min(tensor.numpy(), axis=(1, 2, 3), keepdims=True)
    # eps guards against division by zero when a sample is constant
    return torch.from_numpy((tensor.numpy() - mn) / (mx - mn + eps))


def sizeIsValid(size):
    # Simulate the size bookkeeping of the network: four conv+pool stages down,
    # then four conv+upsample stages up. Returns the output size, or 0 if a
    # pooled size would be odd (the skip-connection crops would misalign).
    # e.g. sizeIsValid(428) == 244 and sizeIsValid(572) == 388.
    for i in range(4):
        size -= 4
        if size % 2:
            return 0
        size //= 2
    for i in range(4):
        size -= 4
        size *= 2
    return size - 4



opt = option()
target_size = sizeIsValid(opt.size)
print("outputsize is: "+str(target_size))
if not target_size:
    raise Exception("input size invalid")
target_gap = (opt.size - target_size)//2
cuda = opt.cuda
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')
train_set = get_training_set(opt.size + opt.remsize, target_mode=opt.target_mode, colordim=opt.colordim)
test_set = get_test_set(opt.size, target_mode=opt.target_mode, colordim=opt.colordim)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

print('===> Building unet')
unet = UNet(opt.colordim)


criterion = nn.MSELoss()
if cuda:
    unet = unet.cuda()
    criterion = criterion.cuda()

pretrained = True
if pretrained:
    unet.load_state_dict(torch.load(opt.pretrain_net))

optimizer = optim.SGD(unet.parameters(), lr=opt.lr)
print('===> Training unet')

def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        randH = random.randint(0, opt.remsize)
        randW = random.randint(0, opt.remsize)
        input = Variable(batch[0][:, :, randH:randH + opt.size, randW:randW + opt.size])
        target = Variable(batch[1][:, :,
                         randH + target_gap:randH + target_gap + target_size,
                         randW + target_gap:randW + target_gap + target_size])
        #target =target.squeeze(1)
        #print(target.data.size())
        if cuda:
            input = input.cuda()
            target = target.cuda()
        optimizer.zero_grad()  # clear gradients accumulated in the previous iteration
        output = unet(input)
        loss = criterion(output, target)
        epoch_loss += loss.data[0]
        loss.backward()
        optimizer.step()
        if iteration % 10 == 0:
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))
    imgout = output.data / 2 + 0.5  # map the softsign output from [-1, 1] to [0, 1]
    torchvision.utils.save_image(imgout, "/home/wcd/PytorchProject/Unet/unetdata/checkpoint/epch_" + str(epoch) + '.jpg')
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))


def test():
    totalloss = 0
    for batch in testing_data_loader:
        input = Variable(batch[0],volatile=True)
        target = Variable(batch[1][:, :,
                          target_gap:target_gap + target_size,
                          target_gap:target_gap + target_size],
                          volatile=True)
        #target =target.long().squeeze(1)
        if cuda:
            input = input.cuda()
            target = target.cuda()
        prediction = unet(input)  # no zero_grad/backward needed during evaluation
        loss = criterion(prediction, target)
        totalloss += loss.data[0]
    print("===> Avg. test loss: {:.4f}".format(totalloss / len(testing_data_loader)))


def checkpoint(epoch):
    model_out_path = "/home/wcd/PytorchProject/Unet/unetdata/checkpoint/model_epoch_{}.pth".format(epoch)
    torch.save(unet.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

for epoch in range(141, 141 + opt.nEpochs + 1):  # resume counting from the loaded epoch-140 checkpoint
    train(epoch)
    if epoch % 10 == 0:
        checkpoint(epoch)
    test()
checkpoint(epoch)
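Once trained, a checkpoint can be loaded back for inference along these lines (my sketch; the path is a placeholder). Note that test() above never calls unet.eval(), which would switch BatchNorm to its running statistics and is generally advisable at test time:

unet.load_state_dict(torch.load('/dir/to/checkpoint/model_epoch_150.pth'))
unet.eval()  # use running BatchNorm statistics instead of per-batch ones
batch = next(iter(testing_data_loader))
prediction = unet(Variable(batch[0].cuda(), volatile=True))
torchvision.utils.save_image(prediction.data / 2 + 0.5, 'prediction.jpg')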



If you want to inspect the structure of the network, you can also do the following:

from graphviz import Digraph
from torch.autograd import Variable
from unet import UNet

def make_dot