1. 程式人生 > >Pytorch入門學習(九)---detach()的作用(從GAN程式碼分析)

Pytorch入門學習(九)---detach()的作用(從GAN程式碼分析)

(八)還沒寫,先跳過。。。

總說

簡單來說detach就是截斷反向傳播的梯度流

    def detach(self):
        """Returns a new Variable, detached from the current graph.

        Result will never require gradient. If the input is volatile, the output
        will be volatile too.

        .. note::

          Returned Variable uses the same data tensor, as the original one, and
          in-place modifications on either of them will be seen, and may trigger
          errors in correctness checks.
        """
result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result

可以看到Returns a new Variable, detached from the current graph。將某個node變成不需要梯度的Varibale。因此當反向傳播經過這個node時,梯度就不會從這個node往前面傳播。

從GAN的程式碼中看detach()

GAN的G的更新,主要是GAN loss。就是G生成的fake圖讓D來判別,得到的損失,計算梯度進行反傳。這個梯度只能影響G,不能影響D!

可以看到,由於torch是非自動求導的,每一層的梯度的計算必須用net:backward才能計算gradInput和網路中的引數的梯度。

先看Torch版本的程式碼

local fGx = function(x)
    netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
    netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

    gradParametersG:zero()

    -- GAN loss
local df_dg = torch.zeros(fake_B:size()) if opt.use_GAN==1 then local output = netD.output -- netD:forward{input_A,input_B} was already executed in fDx, so save computation local label = torch.FloatTensor(output:size()):fill(real_label) -- fake labels are real for generator cost errG = criterion:forward(output, label) local df_do = criterion:backward(output, label) df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc) else errG = 0 end -- unary loss -- 得到 df_do_AE(已省略) netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda)) return errG, gradParametersG end

在下面程式碼中,是先得到fake圖進入D的loss,然後這個loss的梯度df_do進行反傳,首先要這個梯度經過D。此時不能改變D的引數的梯度,所以這裡用updateGradInput,不能用backward。這是因為backward是呼叫2個函式updateGradInputaccGradParameters。後者是計算loss對於網路中引數的梯度,這些梯度是不斷累加的!除非手動gradParametersG:zero()置零。

       errG = criterion:forward(output, label)
       local df_do = criterion:backward(output, label)
       df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
       -- unary loss
       -- 得到 df_do_AE(已省略)   
       netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))

然後得到的df_dg才是要更新G的GAN損失的梯度,當然G的另一個損失是L1損失(unary loss)這個沒啥好說了。

pytorch的GAN實現

由於Pytorch是自動反向傳播,

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_B
        # fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = self.real_B # GroundTruth
        # real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = self.fake_B
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()


    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # 先呼叫 forward, 再 D backward, 更新D之後; 再G backward, 再更新G
    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

解釋backward_D:

對於D,我們值需要,如果輸入是真實圖,那麼產生loss,輸入真實圖,也產生loss。
這兩個梯度進行更新D。如果是真實圖(real_B),由於real_B是初始結點,所以沒什麼可擔心的。但是對於生成圖fake_B,由於 fake_B是由 netG.forward(real_A)產生的。我們只希望 該loss更新D不要影響到 G. 因此這裡需要“截斷反傳的梯度流”,用 fake_AB = fake_B, fake_AB.detach()從而讓梯度不要通過 fake_AB反傳到netG中!

解釋backward_G:

由於在呼叫 backward_G已經呼叫了zero_grad,所以沒什麼好擔心的。
更新G時,來自D的GAN損失是, netD.forward(fake_AB),得到 pred_fake,然後得到損失,反傳播即可。
注意,這裡反向傳播時,會先將梯度傳到 fake_AB結點,然而我們知道 fake_AB即 fake_B結點,而fake_B正是由netG(real_A)產生的,所以還會順著繼續往前傳播,從而得到G的對應的梯度。

對比 Torch程式碼

df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))

Torch中,沒有計算netD的引數的梯度,而是直接用 updateGradInput。在pytorch中,我們也是希望GAN loss只能更新G。但是pytorch是自動求導的,所以我們沒法手動像Torch一樣只調用updateGradInput

        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

在這裡,雖然pytorch中會自動計算所有的結點的梯度,但是我們執行loss_G.backward()後,按照Torch的理解是,這裡直接呼叫backward。即不僅呼叫了updateGradInput(我們只需要這個),還額外的計算了accGradParameters(這個是沒用的),但是看到,在optimize_parameters中,只是進行 optimizer_G.step()所以只會更新G的引數。所以沒有更新D(雖然此時D中有dummy gradient)。等下一回合,又呼叫 optimizer_D.zero_grad(), 因此會把剛才殘留的D的梯度清空。所以仍舊是符合的。

自動求導反向書寫的簡潔

得出結論,書寫自動求導的程式碼完全還是很簡潔的。只需要進行loss計算。loss可以直接相加,然後loss.backward()即可。loss的定義比如:

self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
            lr=opt.lr, betas=(opt.beta1, 0.999))

Adam是繼承自Optimizer類。該類的step函式會將構建loss的所有的Variable的引數進行更新。

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']: 
             #如果這個引數有沒有grad(這個Variable的requries_grad為False)
             #則直接跳過。
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # 對p.data進行更新!就是對引數進行更新!

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)