1. 程式人生 > >用pytorch實現GAN——mnist(含有全部註釋和網路思想)

用pytorch實現GAN——mnist(含有全部註釋和網路思想)

#coding=utf-8
import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import  save_image
import os

#建立資料夾
if not os.path.exists('./img'):
    os.mkdir('./img')

def to_img(x):
    out=0.5*(x+1)
    out=out.clamp(0,1)#Clamp函式可以將隨機變化的數值限制在一個給定的區間[min, max]內:
    out=out.view(-1,1,28,28)#view()函式作用是將一個多行的Tensor,拼接成一行
    return out

batch_size=128
num_epoch=100
z_dimension=100

#圖形啊處理過程
img_transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))
])

#mnist dataset mnist資料集下載
mnist=datasets.MNIST(
    root='./data/',train=True,transform=img_transform,download=True
)

#data loader 資料載入
dataloader=torch.utils.data.DataLoader(
    dataset=mnist,batch_size=batch_size,shuffle=True
)


#定義判別器  #####Discriminator######使用多層網路來作為判別器

#將圖片28x28展開成784,然後通過多層感知器,中間經過斜率設定為0.2的LeakyReLU啟用函式,
# 最後接sigmoid啟用函式得到一個0到1之間的概率進行二分類。
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.dis=nn.Sequential(
            nn.Linear(784,256),#輸入特徵數為784,輸出為256
            nn.LeakyReLU(0.2),#進行非線性對映
            nn.Linear(256,256),#進行一個線性對映
            nn.LeakyReLU(0.2),
            nn.Linear(256,1),
            nn.Sigmoid()#也是一個啟用函式,二分類問題中,
            # sigmoid可以班實數對映到【0,1】,作為概率值,
            # 多分類用softmax函式
        )
    def forward(self, x):
        x=self.dis(x)
        return x


####### 定義生成器 Generator #####
#輸入一個100維的0~1之間的高斯分佈,然後通過第一層線性變換將其對映到256維,
# 然後通過LeakyReLU啟用函式,接著進行一個線性變換,再經過一個LeakyReLU啟用函式,
# 然後經過線性變換將其變成784維,最後經過Tanh啟用函式是希望生成的假的圖片資料分佈
# 能夠在-1~1之間。
class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        self.gen=nn.Sequential(
            nn.Linear(100,256),#用線性變換將輸入對映到256維
            nn.ReLU(True),#relu啟用
            nn.Linear(256,256),#線性變換
            nn.ReLU(True),#relu啟用
            nn.Linear(256,784),#線性變換
            nn.Tanh()#Tanh啟用使得生成資料分佈在【-1,1】之間
        )

    def forward(self, x):
        x=self.gen(x)
        return x

#建立物件
D=discriminator()
G=generator()
if torch.cuda.is_available():
    D=D.cuda()
    G=G.cuda()



#########判別器訓練train#####################
#分為兩部分:1、真的影象判別為真;2、假的影象判別為假
#此過程中,生成器引數不斷更新

#首先需要定義loss的度量方式  (二分類的交叉熵)
#其次定義 優化函式,優化函式的學習率為0.0003
criterion = nn.BCELoss() #是單目標二分類交叉熵函式
d_optimizer=torch.optim.Adam(D.parameters(),lr=0.0003)
g_optimizer=torch.optim.Adam(G.parameters(),lr=0.0003)

###########################進入訓練##判別器的判斷過程#####################

for epoch in range(num_epoch): #進行多個epoch的訓練
    for i,(img, _) in enumerate(dataloader):
        num_img=img.size(0)
        # view()函式作用是將一個多行的Tensor,拼接成一行
        # 第一個引數是要拼接的tensor,第二個引數是-1
        # =============================訓練判別器==================
        img = img.view(num_img, -1)  # 將圖片展開為28*28=784
        real_img = Variable(img).cuda()  # 將tensor變成Variable放入計算圖中
        real_label = Variable(torch.ones(num_img)).cuda()  # 定義真實的圖片label為1
        fake_label = Variable(torch.zeros(num_img)).cuda()  # 定義假的圖片的label為0

        # 計算真實圖片的損失
        real_out = D(real_img)  # 將真實圖片放入判別器中
        d_loss_real = criterion(real_out, real_label)  # 得到真實圖片的loss
        real_scores = real_out  # 得到真實圖片的判別值,輸出的值越接近1越好

        # 計算假的圖片的損失
        z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 隨機生成一些噪聲
        fake_img = G(z)  # 隨機噪聲放入生成網路中,生成一張假的圖片
        fake_out = D(fake_img)  # 判別器判斷假的圖片
        d_loss_fake = criterion(fake_out, fake_label)  # 得到假的圖片的loss
        fake_scores = fake_out  # 得到假圖片的判別值,對於判別器來說,假圖片的損失越接近0越好

        # 損失函式和優化
        d_loss = d_loss_real + d_loss_fake #損失包括判真損失和判假損失
        d_optimizer.zero_grad()  # 在反向傳播之前,先將梯度歸0
        d_loss.backward()  # 將誤差反向傳播
        d_optimizer.step()  # 更新引數

        # ==================訓練生成器============================
        ################################生成網路的訓練###############################
        # 原理:目的是希望生成的假的圖片被判別器判斷為真的圖片,
        # 在此過程中,將判別器固定,將假的圖片傳入判別器的結果與真實的label對應,
        # 反向傳播更新的引數是生成網路裡面的引數,
        # 這樣可以通過更新生成網路裡面的引數,來訓練網路,使得生成的圖片讓判別器以為是真的
        # 這樣就達到了對抗的目的

        # 計算假的圖片的損失

        z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 得到隨機噪聲
        fake_img = G(z) #隨機噪聲輸入到生成器中,得到一副假的圖片
        output = D(fake_img)  # 經過判別器得到的結果
        g_loss = criterion(output, real_label)  # 得到的假的圖片與真實的圖片的label的loss

        # bp and optimize
        g_optimizer.zero_grad()  # 梯度歸0
        g_loss.backward()  # 進行反向傳播
        g_optimizer.step()  # .step()一般用在反向傳播後面,用於更新生成網路的引數

        #列印中間的損失
        if (i+1)%100==0:
            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D real: {:.6f},D fake: {:.6f}'.format(
                epoch,num_epoch,d_loss.data[0],g_loss.data[0],
                real_scores.data.mean(),fake_scores.data.mean()  #列印的是真實圖片的損失均值
            ))

        if epoch==0:
            real_images=to_img(real_img.cpu().data)
            save_image(real_images, './img/real_images.png')

        fake_images = to_img(real_img.cpu().data)
        save_image(fake_images, './img/fake_images-{}.png'.format(epoch+1))
#儲存模型
torch.save(G.state_dict(),'./generator.pth')
torch.save(D.state_dict(),'./discriminator.pth')