1. 程式人生 > >【超解析度】Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

【超解析度】Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

之前我一直在做基於CNN的超解析度研究。最近因為工作需要,需要研究基於生成對抗網路GAN的網路來做超解析度任務。
在這段時間以來,我發現CNN和GAN兩類網路的側重點其實完全不同。CNN旨在於忠實的恢復影象的高頻資訊,而GAN在於生成更真實或者說更符合人眼的高解析度圖片。

基於GAN和基於MSE的超解析度影象的側重點


這裡寫圖片描述

MSE-based解決方案由於畫素化的平均值,會顯得過於平滑,而GAN驅動的超解析度解決方案則對自然影象集合的重構產生了感知上更有說服力的解,看起來更加自然

本文摘要:

在CNN的發展中,超解析度任務在精確度和速度上都有了很大的發展。但是在較大的上取樣因子下,如何去恢復更精細的紋理資訊仍然等待著處理。目前現有的超解析度演算法產生的影象都具有很高的PSNR,但是這類演算法產生的結果往往缺乏高頻細節,而且感知上並不令人滿意。在這篇文章中,作者提出用GAN來做影象超解析度。在本文中,作者提出了一種感知損失perceptual loss,該loss function包含了對抗損失及內容損失。通過這種loss來驅動,使得生成的高解析度影象更加自然。大量的主觀評分(MOS)證明,SRGAN產生的圖片更接近真實影象

本文contributon:

1) 本文提出了SRResNet CNN框架,並且以PSNR和SSIM為優化目的,構造了16個blocks的deep ResNet
2) 本文提出了基於GAN的並以perceptual loss 為目的的SRGAN網路, 在訓練GAN網路時,本文加入VGG loss作為優化目的,VGG loss在畫素空間更具有不變性。
3)本文采取了MOS主觀意見評價方法,對SRGAN生成的影象進行打分。

網路結構:

這裡寫圖片描述

Generator 通過deep ResNet來學習 LR到HR之間的對映,並且以PSNR和SSIM作為其生成指標。同時用判別器來判別生成的圖片是否屬於自然影象。

整個網路結構還是生成對抗網路的套路,目的是去優化min-max problem
這裡寫圖片描述

整個目標函式在GAN中非常常見,也是GAN的本質,其目的是為了訓練一個生成器G來fool 判別器D,再訓練判別器D來判別生成器G生成的圖片還是自然真實影象。

生成器loss

content loss:

該損失函式在超解析度網路中極為常見,即通過求MSE最小來對生成器網路進行更新,然而MSE 優化生成的超分影象往往會缺少高頻資訊,從而使得過平滑。
這裡寫圖片描述

VGG loss:

這個損失函式來自於李飛飛的感知loss的一篇論文,通過將兩張圖片投入VGG網路中,然後求解兩張特徵影象的mse來進行優化。由於VGG loss來自於比較深層的網路提取出來的特徵,因此這個損失更能夠保證感知相似度。
這裡寫圖片描述

Adversarial loss:

這個損失函式在GAN的生成器中極為常見,一般來說如果只用對抗損失會使得網路訓練起來很難收斂,而加入了之前的MSE loss和VGG loss後,能夠保證網路的收斂。
這裡寫圖片描述

為了簡化對抗損失 adversarial loss,我們對上式 用logDθD(GθD(ILR))
來替換 log[1DθD(GθD(ILR))]

那麼生成器的loss則很明確了

(70)gloss=gcontentloss+gVGGloss+gadversarial

程式碼描述為:

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

判別器loss

判別器loss只有對抗損失函式
這裡寫圖片描述

程式碼描述為:

    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

實驗結果:

這裡寫圖片描述