[1609.04802] SRGAN中的那些loss
阿新 • • 發佈:2019-01-10
1. loss
深度神經網路模型的設計中,loss絕對要佔據一席之位。不同的loss形式,對優化的結果,差別很大。啥也不說,先上程式碼。
2. 程式碼
for epoch in range(args.num_epochs):
for i, data in enumerate(dataloader_train):
# forward
x, target = data
x = Variable(x)
y_real = Variable(target)
target_real = Variable(torch.rand(args.batch_size, 1 )*0.5 + 0.7)
target_fake = Variable(torch.rand(args.batch_size, 1)*0.3)
y_fake = generator(x)
# train discriminator
discriminator.zero_grad()
discriminator_loss = adversarial_criterion(discriminator(y_real), target_real) + \
adversarial_criterion(discriminator(y_fake), target_fake)
mean_discriminator_loss += discriminator_loss.data[0 ]
discriminator_loss.backward()
optim_discriminator.step()
# train generator
generator.zero_grad()
features_real = Variable(feature_extractor(y_real).data)
features_fake = Variable(feature_extractor(y_fake))
generator_content_loss = content_criterion(y_fake, y_real) + \
content_criterion(features_fake, features_real) * 0.006
mean_generator_content_loss += generator_content_loss.data[0]
generator_adversarial_loss = adversarial_criterion(discriminator(y_fake), ones_const)
mean_generator_adversarial_loss += generator_adversarial_loss.data[0]
generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss
mean_generator_total_loss += generator_total_loss.data[0]
generator_total_loss.backward()
optim_generator.step()
典型的pytorch訓練迴圈的程式碼。程式碼中出現了discriminator_loss、generator_content_loss、generator_adversarial_loss三種loss,第一個用來訓練判別器,後兩個加起來,訓練生成器。
上面計算loss的程式碼,視覺化框圖如下
訓練流程:
- 沿著紅色虛線,計算判別損失,更新判別器引數
- 沿著粉色虛線,計算產生損失,更新產生器引數
3. 和不用GAN的區別
如果不用GAN,模型僅僅用一個G網路,產生y_fake,和y_real求得MSEloss,用這個損失更新網路引數。
而GAN的作用,是額外增加一個D網路和2個損失(判別損失和生成判別損失),用一種交替訓練的方式訓練兩個網路。這個模型可以分為3部分:main模組,adversarial模組,和vgg模組。(一般main模組就是adversarial模組裡的G網路)adversarial可以看作是一種訓練技巧,只在訓練階段會用到adversarial模組進行計算,而在推斷階段,僅僅使用G網路(或者說main模組)。
也就是任何一個問題,都可以讓訓練過程“對抗化”。“對抗化”的步驟是:
- 確定main模組(原始問題的解決辦法)
- 把main模組當成GAN中的G網路
- 另外增加一個D網路(二分類網路)
- 在原來更新main模組的loss中,增加“生成對抗損失”(要生成讓判別器無法區分的資料分佈),一起用來更新main模組(也就是GAN中的G網路)
- 用判別損失更新GAN中的D網路