1. 程式人生 > >[1609.04802] SRGAN中的那些loss

[1609.04802] SRGAN中的那些loss

1. loss

深度神經網路模型的設計中,loss絕對要佔據一席之位。不同的loss形式,對優化的結果,差別很大。啥也不說,先上程式碼。

2. 程式碼

for epoch in range(args.num_epochs):
    for i, data in enumerate(dataloader_train):
        # forward
        x, target = data
        x = Variable(x)
        y_real = Variable(target)
        target_real = Variable(torch.rand(args.batch_size, 1
)*0.5 + 0.7) target_fake = Variable(torch.rand(args.batch_size, 1)*0.3) y_fake = generator(x) # train discriminator discriminator.zero_grad() discriminator_loss = adversarial_criterion(discriminator(y_real), target_real) + \ adversarial_criterion(discriminator(y_fake), target_fake) mean_discriminator_loss += discriminator_loss.data[0
] discriminator_loss.backward() optim_discriminator.step() # train generator generator.zero_grad() features_real = Variable(feature_extractor(y_real).data) features_fake = Variable(feature_extractor(y_fake)) generator_content_loss = content_criterion(y_fake, y_real) + \ content_criterion(features_fake, features_real) * 0.006
mean_generator_content_loss += generator_content_loss.data[0] generator_adversarial_loss = adversarial_criterion(discriminator(y_fake), ones_const) mean_generator_adversarial_loss += generator_adversarial_loss.data[0] generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss mean_generator_total_loss += generator_total_loss.data[0] generator_total_loss.backward() optim_generator.step()

典型的pytorch訓練迴圈的程式碼。程式碼中出現了discriminator_lossgenerator_content_lossgenerator_adversarial_loss三種loss,第一個用來訓練判別器,後兩個加起來,訓練生成器。

上面計算loss的程式碼,視覺化框圖如下
這裡寫圖片描述

訓練流程:
- 沿著紅色虛線,計算判別損失,更新判別器引數Dθ
- 沿著粉色虛線,計算產生損失,更新產生器引數Gθ

3. 和不用GAN的區別

如果不用GAN,模型僅僅用一個G網路,產生y_fake,和y_real求得MSEloss,用這個損失更新網路引數。

而GAN的作用,是額外增加一個D網路和2個損失(判別損失和生成判別損失),用一種交替訓練的方式訓練兩個網路。這個模型可以分為3部分:main模組,adversarial模組,和vgg模組。(一般main模組就是adversarial模組裡的G網路)adversarial可以看作是一種訓練技巧,只在訓練階段會用到adversarial模組進行計算,而在推斷階段,僅僅使用G網路(或者說main模組)。

也就是任何一個問題,都可以讓訓練過程“對抗化”。“對抗化”的步驟是:
- 確定main模組(原始問題的解決辦法)
- 把main模組當成GAN中的G網路
- 另外增加一個D網路(二分類網路)
- 在原來更新main模組的loss中,增加“生成對抗損失”(要生成讓判別器無法區分的資料分佈),一起用來更新main模組(也就是GAN中的G網路)
- 用判別損失更新GAN中的D網路