用pytorch實現GAN——mnist(含有全部註釋和網路思想)
阿新 • • 發佈:2018-12-28
#coding=utf-8 import torch.autograd import torch.nn as nn from torch.autograd import Variable from torchvision import transforms from torchvision import datasets from torchvision.utils import save_image import os #建立資料夾 if not os.path.exists('./img'): os.mkdir('./img') def to_img(x): out=0.5*(x+1) out=out.clamp(0,1)#Clamp函式可以將隨機變化的數值限制在一個給定的區間[min, max]內: out=out.view(-1,1,28,28)#view()函式作用是將一個多行的Tensor,拼接成一行 return out batch_size=128 num_epoch=100 z_dimension=100 #圖形啊處理過程 img_transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)) ]) #mnist dataset mnist資料集下載 mnist=datasets.MNIST( root='./data/',train=True,transform=img_transform,download=True ) #data loader 資料載入 dataloader=torch.utils.data.DataLoader( dataset=mnist,batch_size=batch_size,shuffle=True ) #定義判別器 #####Discriminator######使用多層網路來作為判別器 #將圖片28x28展開成784,然後通過多層感知器,中間經過斜率設定為0.2的LeakyReLU啟用函式, # 最後接sigmoid啟用函式得到一個0到1之間的概率進行二分類。 class discriminator(nn.Module): def __init__(self): super(discriminator,self).__init__() self.dis=nn.Sequential( nn.Linear(784,256),#輸入特徵數為784,輸出為256 nn.LeakyReLU(0.2),#進行非線性對映 nn.Linear(256,256),#進行一個線性對映 nn.LeakyReLU(0.2), nn.Linear(256,1), nn.Sigmoid()#也是一個啟用函式,二分類問題中, # sigmoid可以班實數對映到【0,1】,作為概率值, # 多分類用softmax函式 ) def forward(self, x): x=self.dis(x) return x ####### 定義生成器 Generator ##### #輸入一個100維的0~1之間的高斯分佈,然後通過第一層線性變換將其對映到256維, # 然後通過LeakyReLU啟用函式,接著進行一個線性變換,再經過一個LeakyReLU啟用函式, # 然後經過線性變換將其變成784維,最後經過Tanh啟用函式是希望生成的假的圖片資料分佈 # 能夠在-1~1之間。 class generator(nn.Module): def __init__(self): super(generator,self).__init__() self.gen=nn.Sequential( nn.Linear(100,256),#用線性變換將輸入對映到256維 nn.ReLU(True),#relu啟用 nn.Linear(256,256),#線性變換 nn.ReLU(True),#relu啟用 nn.Linear(256,784),#線性變換 nn.Tanh()#Tanh啟用使得生成資料分佈在【-1,1】之間 ) def forward(self, x): x=self.gen(x) return x #建立物件 D=discriminator() G=generator() if torch.cuda.is_available(): D=D.cuda() G=G.cuda() #########判別器訓練train##################### #分為兩部分:1、真的影象判別為真;2、假的影象判別為假 #此過程中,生成器引數不斷更新 #首先需要定義loss的度量方式 (二分類的交叉熵) #其次定義 優化函式,優化函式的學習率為0.0003 criterion = nn.BCELoss() #是單目標二分類交叉熵函式 d_optimizer=torch.optim.Adam(D.parameters(),lr=0.0003) g_optimizer=torch.optim.Adam(G.parameters(),lr=0.0003) ###########################進入訓練##判別器的判斷過程##################### for epoch in range(num_epoch): #進行多個epoch的訓練 for i,(img, _) in enumerate(dataloader): num_img=img.size(0) # view()函式作用是將一個多行的Tensor,拼接成一行 # 第一個引數是要拼接的tensor,第二個引數是-1 # =============================訓練判別器================== img = img.view(num_img, -1) # 將圖片展開為28*28=784 real_img = Variable(img).cuda() # 將tensor變成Variable放入計算圖中 real_label = Variable(torch.ones(num_img)).cuda() # 定義真實的圖片label為1 fake_label = Variable(torch.zeros(num_img)).cuda() # 定義假的圖片的label為0 # 計算真實圖片的損失 real_out = D(real_img) # 將真實圖片放入判別器中 d_loss_real = criterion(real_out, real_label) # 得到真實圖片的loss real_scores = real_out # 得到真實圖片的判別值,輸出的值越接近1越好 # 計算假的圖片的損失 z = Variable(torch.randn(num_img, z_dimension)).cuda() # 隨機生成一些噪聲 fake_img = G(z) # 隨機噪聲放入生成網路中,生成一張假的圖片 fake_out = D(fake_img) # 判別器判斷假的圖片 d_loss_fake = criterion(fake_out, fake_label) # 得到假的圖片的loss fake_scores = fake_out # 得到假圖片的判別值,對於判別器來說,假圖片的損失越接近0越好 # 損失函式和優化 d_loss = d_loss_real + d_loss_fake #損失包括判真損失和判假損失 d_optimizer.zero_grad() # 在反向傳播之前,先將梯度歸0 d_loss.backward() # 將誤差反向傳播 d_optimizer.step() # 更新引數 # ==================訓練生成器============================ ################################生成網路的訓練############################### # 原理:目的是希望生成的假的圖片被判別器判斷為真的圖片, # 在此過程中,將判別器固定,將假的圖片傳入判別器的結果與真實的label對應, # 反向傳播更新的引數是生成網路裡面的引數, # 這樣可以通過更新生成網路裡面的引數,來訓練網路,使得生成的圖片讓判別器以為是真的 # 這樣就達到了對抗的目的 # 計算假的圖片的損失 z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到隨機噪聲 fake_img = G(z) #隨機噪聲輸入到生成器中,得到一副假的圖片 output = D(fake_img) # 經過判別器得到的結果 g_loss = criterion(output, real_label) # 得到的假的圖片與真實的圖片的label的loss # bp and optimize g_optimizer.zero_grad() # 梯度歸0 g_loss.backward() # 進行反向傳播 g_optimizer.step() # .step()一般用在反向傳播後面,用於更新生成網路的引數 #列印中間的損失 if (i+1)%100==0: print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} ' 'D real: {:.6f},D fake: {:.6f}'.format( epoch,num_epoch,d_loss.data[0],g_loss.data[0], real_scores.data.mean(),fake_scores.data.mean() #列印的是真實圖片的損失均值 )) if epoch==0: real_images=to_img(real_img.cpu().data) save_image(real_images, './img/real_images.png') fake_images = to_img(real_img.cpu().data) save_image(fake_images, './img/fake_images-{}.png'.format(epoch+1)) #儲存模型 torch.save(G.state_dict(),'./generator.pth') torch.save(D.state_dict(),'./discriminator.pth')