
Adversarial Network Learning (Part 8): DeblurGAN for Deblurring Motion-Blurred Images (TensorFlow Implementation)

I. Background

DeblurGAN is a model proposed by Orest Kupyn et al. in November 2017. As we have seen earlier in this series, GANs are good at preserving fine texture detail; the SRGAN covered previously, for example, performs image super-resolution. The authors exploit this property and combine a GAN with a multi-component content loss to build DeblurGAN, which removes motion blur from images.

The dataset for this experiment is the GOPRO dataset, which is described in detail later. The goal is to implement DeblurGAN with as little code as possible.

[1] Paper link: https://arxiv.org/pdf/1711.07064.pdf

II. How DeblurGAN Works

DeblurGAN's main novelty lies in combining network structures and loss functions from earlier GAN work. There is not much written about it online; a recommended read is:

[2] A Chinese paper-reading note on "DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks"

The paper shows example results of running DeblurGAN.

From left to right: the blurred input, the image restored by DeblurGAN, and the real sharp image. The results are quite good.

The paper is fairly short, and the authors summarize their main contributions as follows:

We make three contributions. First, we propose a loss and architecture which obtain state-of-the art results in motion deblurring, while being 5x faster than the fastest competitor. Second, we present a method based on random trajectories for generating a dataset for motion deblurring training in an automated fashion from the set of sharp image. We show that combining it with an existing dataset for motion deblurring learning improves results compared to training on real-world images only. Finally, we present a novel dataset and method for evaluation of deblurring algorithms based on how they improve object detection results.

1. A loss function and architecture for motion deblurring that achieve state-of-the-art results while being 5x faster than the fastest competing method.

2. A random-trajectory method for synthesizing blurred training images from sharp ones.

3. A new dataset and evaluation method for deblurring algorithms, based on how much deblurring improves object-detection results.

This post focuses on implementing DeblurGAN. For details of how the synthetic dataset is generated, see [2] or the original paper; here I only give a brief description of the idea, which roughly mimics the image you would get from a long exposure with a shaking camera.

In short, a sharp image is convolved with various "blur kernels" to obtain a synthetic blurred image. The authors generate random motion trajectories with a Markov random process and then apply "sub-pixel interpolation" to each trajectory to obtain a blur kernel. Of course, this only produces trajectories in the 2-D image plane and cannot simulate the 6-DoF motion of a real camera in space [2].
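
As a rough illustration only, the whole "random trajectory, then blur kernel, then convolution" idea can be mocked up in a few lines of NumPy/SciPy. This is my own simplified sketch, not the authors' algorithm: the trajectory is a plain random walk with momentum instead of the paper's Markov trajectory model, and the deposit is nearest-pixel instead of sub-pixel interpolation.

import numpy as np
from scipy.ndimage import convolve

def random_trajectory_kernel(kernel_size=31, steps=200, momentum=0.7, seed=None):
    """Rasterize a 2-D random walk into a normalized blur kernel."""
    rng = np.random.default_rng(seed)
    pos = np.zeros(2)
    vel = rng.normal(size=2)
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float64)
    center = kernel_size // 2
    for _ in range(steps):
        # the next velocity depends only on the previous one plus Gaussian noise
        vel = momentum * vel + (1.0 - momentum) * rng.normal(size=2)
        pos = pos + vel
        # nearest-pixel deposit (the paper uses sub-pixel interpolation here)
        y, x = int(round(center + pos[0])), int(round(center + pos[1]))
        if 0 <= y < kernel_size and 0 <= x < kernel_size:
            kernel[y, x] += 1.0
    total = kernel.sum()
    return kernel / total if total > 0 else None

def synthesize_blur(sharp, kernel):
    """Convolve an H x W x 3 float image with the kernel, channel by channel."""
    return np.stack([convolve(sharp[..., c], kernel, mode='reflect')
                     for c in range(sharp.shape[-1])], axis=-1)

Calling synthesize_blur(sharp, random_trajectory_kernel(seed=0)) on a sharp image then yields a synthetically motion-blurred copy.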

The paper also gives pseudocode for the blurred-image generation algorithm.

As for the network architecture, overall it is not very different from an ordinary GAN.

The generator, however, has a structure similar to an auto-encoder,

while the discriminator has the same structure as in PatchGAN.
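
Concretely, in the implementation below the generator consists of a 7x7 convolution, two stride-2 down-sampling convolutions, nine residual blocks, two transposed convolutions back to full resolution, and a final 7x7 convolution with tanh, plus a global skip connection from the blurred input to the output; the discriminator is a PatchGAN-style stack of stride-2 4x4 convolutions ending in a one-channel patch map.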

The authors also modify the loss function: the new loss is the sum of a content loss and an adversarial loss.
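
Written out (my notation; the weights follow the paper and match the coefficients used in DeblurGAN.py below), the generator objective is

\mathcal{L}_{G} = \mathcal{L}_{adv} + \lambda \cdot \mathcal{L}_{content}, \qquad \lambda = 100,

\mathcal{L}_{content} = \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} \big( \phi_{i,j}(I^{S})_{x,y} - \phi_{i,j}(G_{\theta}(I^{B}))_{x,y} \big)^{2},

where \phi_{i,j} is a VGG19 feature map of size W_{i,j} x H_{i,j} (relu3_3 in the code below), I^{S} is the sharp image and I^{B} the blurred input. The adversarial term is the WGAN-GP critic loss, with a gradient-penalty weight of 10 on the discriminator side.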

For the implementation, I mainly referred to [3] and modified its code. Reference implementations online are scarce; here are a few:

[3]https://github.com/dongheehand/DeblurGAN-tf

[4]https://github.com/LeeDoYup/DeblurGAN-tf

[5]https://github.com/KupynOrest/DeblurGAN

III. Implementing DeblurGAN

1. File structure

The project's file structure is as follows:

-- main.py
-- util.py
-- data_loader.py
-- mode.py
-- DeblurGAN.py
-- vgg19.py
-- layer.py
-- vgg19.npy                            # the pre-trained vgg19 model, downloaded separately (see below)
-- data                                 # the training dataset, also described below
    |------ train
            |------ blur
                    |------ image1.png
                    |------ image2.png
                    |------ ......
            |------ sharp
                    |------ image1.png
                    |------ image2.png
                    |------ ......
    |------ val
            |------ val_blur
                    |------ image1.png
                    |------ image2.png
                    |------ ......
            |------ val_sharp
                    |------ image1.png
                    |------ image2.png
                    |------ ......

2. Preparing the data

Two things need to be prepared: the vgg19 model file and the training dataset.

(1) The vgg19.npy model file

First, the download address for vgg19.npy:

https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs

Open the URL above and download the file directly. Note that the site may not be reachable without a VPN.

For convenience, I have also uploaded the file to Baidu Cloud:

Baidu Cloud link: https://pan.baidu.com/s/1GluBif6N1u9eiosICI12Ng

Extraction code: dzsa

After downloading, place the file in the project root directory, i.e. './vgg19.npy'.

(2) The training dataset

Several versions of the GOPRO dataset exist online. Here is a brief overview with download links (a VPN may be required):

① GOPRO_Large: about 8.9 GB; download link (VPN may be required):

https://drive.google.com/uc?id=1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2&export=download

② GOPRO_Large_all: about 35 GB; download link (VPN may be required):

https://drive.google.com/uc?id=1SlURvdQsokgsoyTosAaELc4zRjQz9T2U&export=download

③ blurred_sharp.zip: about 1.0 GB; download link (VPN may be required):

https://drive.google.com/uc?export=download&confirm=jg11&id=1CPMBmRj-jBDO2ax4CxkBs9iczIFrs8VA

If you can open the links above, just download the data directly.

In case the links are not accessible, I have also uploaded the dataset to Baidu Cloud:

Baidu Cloud link: https://pan.baidu.com/s/1PG_yzQqEu6qYr7qQSfyW0Q

Extraction code: 58u2

After downloading and extracting, you will find two folders, 'blurred' and 'sharp', under './blurred_sharp/blurred_sharp/'; together they form the training data. Move all images from 'blurred' into './data/train/blur/' and all images from 'sharp' into './data/train/sharp/'. That completes the training set, but we still need to set aside part of the data for testing.

I cut 5 images from './data/train/blur/' into './data/val/val_blur/' and, likewise, the 5 correspondingly numbered images from './data/train/sharp/' into './data/val/val_sharp/'. Note that the file names in the two groups must correspond exactly.

The constructed dataset then matches the layout given in the file structure above.

Checking the image properties shows that all images are 720*720 PNG files. With the dataset in place, the experiment can begin.

3. The data-loading file data_loader.py

data_loader.py contains the data-loading functions. The code is given below:

import tensorflow as tf
import os

class dataloader():
    
    def __init__(self, args):

        self.channel = 3

        self.mode = args.mode
        self.patch_size = args.patch_size
        self.batch_size = args.batch_size
        self.train_Sharp_path = args.train_Sharp_path
        self.train_Blur_path = args.train_Blur_path
        self.test_Sharp_path = args.test_Sharp_path
        self.test_Blur_path = args.test_Blur_path
        self.test_with_train = args.test_with_train
        self.test_batch = args.test_batch
        self.load_X = args.load_X
        self.load_Y = args.load_Y
        self.augmentation = args.augmentation
        
    def build_loader(self):
        
        if self.mode == 'train':
        
            tr_sharp_imgs = sorted(os.listdir(self.train_Sharp_path))
            tr_blur_imgs = sorted(os.listdir(self.train_Blur_path))
            tr_sharp_imgs = [os.path.join(self.train_Sharp_path, ele) for ele in tr_sharp_imgs]
            tr_blur_imgs = [os.path.join(self.train_Blur_path, ele) for ele in tr_blur_imgs]
            train_list = (tr_blur_imgs, tr_sharp_imgs)
            
            self.tr_dataset = tf.data.Dataset.from_tensor_slices(train_list)
            self.tr_dataset = self.tr_dataset.map(self._parse, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._resize, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._get_patch, num_parallel_calls = 4).prefetch(32)
            if self.augmentation:
                self.tr_dataset = self.tr_dataset.map(self._data_augmentation, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.shuffle(32)
            self.tr_dataset = self.tr_dataset.repeat()
            self.tr_dataset = self.tr_dataset.batch(self.batch_size)
            
            if self.test_with_train:
            
                val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
                val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
                val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
                val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
                valid_list = (val_blur_imgs, val_sharp_imgs)

                self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
                self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
                self.val_dataset = self.val_dataset.batch(self.test_batch)

            iterator = tf.data.Iterator.from_structure(self.tr_dataset.output_types, self.tr_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['tr_init'] = iterator.make_initializer(self.tr_dataset)
            
            if self.test_with_train:
                self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)
            
        elif self.mode == 'test':
            
            val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
            val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
            val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
            val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
            valid_list = (val_blur_imgs, val_sharp_imgs)
            
            self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
            self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
            self.val_dataset = self.val_dataset.batch(1)
            
            iterator = tf.data.Iterator.from_structure(self.val_dataset.output_types, self.val_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)
             
    def _parse(self, image_blur, image_sharp):
        
        image_blur = tf.read_file(image_blur)
        image_sharp = tf.read_file(image_sharp)
        
        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        image_sharp = tf.image.decode_image(image_sharp, channels=self.channel)

        # decode_image returns a tensor of unknown rank (it may decode GIFs as 4-D),
        # which breaks tf.image.resize_images below, so pin the static shape here
        image_blur.set_shape([None, None, self.channel])
        image_sharp.set_shape([None, None, self.channel])

        image_blur = tf.cast(image_blur, tf.float32)
        image_sharp = tf.cast(image_sharp, tf.float32)
        
        return image_blur, image_sharp
    
    def _resize(self, image_blur, image_sharp):
        
        image_blur = tf.image.resize_images(image_blur, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)
        image_sharp = tf.image.resize_images(image_sharp, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)
        
        return image_blur, image_sharp

    def _parse_Blur_only(self, image_blur):
        
        image_blur = tf.read_file(image_blur)
        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        image_blur = tf.cast(image_blur, tf.float32)
        
        return image_blur
        
    def _get_patch(self, image_blur, image_sharp):
        
        shape = tf.shape(image_blur)
        ih = shape[0]
        iw = shape[1]
        
        ix = tf.random_uniform(shape=[1], minval=0, maxval=iw - self.patch_size + 1, dtype=tf.int32)[0]
        iy = tf.random_uniform(shape=[1], minval=0, maxval=ih - self.patch_size + 1, dtype=tf.int32)[0]
        
        img_sharp_in = image_sharp[iy:iy + self.patch_size, ix:ix + self.patch_size]        
        img_blur_in = image_blur[iy:iy + self.patch_size, ix:ix + self.patch_size]
        
        return img_blur_in, img_sharp_in
    
    def _data_augmentation(self, image_blur, image_sharp):
        
        rot = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        flip_rl = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        flip_updown = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        
        image_blur = tf.image.rot90(image_blur, rot)
        image_sharp = tf.image.rot90(image_sharp, rot)
        
        rl = tf.equal(tf.mod(flip_rl, 2), 0)
        ud = tf.equal(tf.mod(flip_updown, 2), 0)
        
        image_blur = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_sharp),
                              false_fn=lambda: image_sharp)
        
        image_blur = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_sharp),
                              false_fn=lambda: image_sharp)
        
        return image_blur, image_sharp
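
Note that in this walkthrough the tf.data pipeline above is only instantiated by DeblurGAN.__init__; the training loop in mode.py (Section 8) actually feeds placeholders using util.py instead. The loader can still be driven on its own; here is a minimal sketch, where the Namespace merely mimics the argparse arguments defined in main.py:

import tensorflow as tf
from argparse import Namespace
from data_loader import dataloader

# dummy arguments mirroring the defaults in main.py
args = Namespace(mode='train', patch_size=256, batch_size=1,
                 train_Sharp_path='./data/train/sharp/',
                 train_Blur_path='./data/train/blur/',
                 test_Sharp_path='./data/val/val_sharp/',
                 test_Blur_path='./data/val/val_blur/',
                 test_with_train=False, test_batch=1,
                 load_X=640, load_Y=360, augmentation=False)

loader = dataloader(args)
loader.build_loader()

with tf.Session() as sess:
    sess.run(loader.init_op['tr_init'])
    blur_batch, sharp_batch = sess.run(loader.next_batch)
    print(blur_batch.shape, sharp_batch.shape)  # (1, 256, 256, 3) for both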

4. The vgg19 file vgg19.py

vgg19.py is used to load the pre-trained vgg19 model. The code:

import tensorflow as tf
import numpy as np
import time


VGG_MEAN = [103.939, 116.779, 123.68]


class Vgg19:

    def __init__(self, vgg19_npy_path):
        # allow_pickle=True is needed to load the weight dictionary on newer NumPy versions
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG
        :param rgb: rgb image [batch, height, width, 3] values scaled [-1, 1]
        """

        start_time = time.time()
        print("build vgg19 model started")
        rgb_scaled = ((rgb + 1) * 255.0) / 2.0

        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2]])

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.relu1_1 = self.relu_layer(self.conv1_1, "relu1_1")
        self.conv1_2 = self.conv_layer(self.relu1_1, "conv1_2")
        self.relu1_2 = self.relu_layer(self.conv1_2, "relu1_2")
        self.pool1 = self.max_pool(self.relu1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.relu2_1 = self.relu_layer(self.conv2_1, "relu2_1")
        self.conv2_2 = self.conv_layer(self.relu2_1, "conv2_2")
        self.relu2_2 = self.relu_layer(self.conv2_2, "relu2_2")
        self.pool2 = self.max_pool(self.relu2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.relu3_1 = self.relu_layer(self.conv3_1, "relu3_1")
        self.conv3_2 = self.conv_layer(self.relu3_1, "conv3_2")
        self.relu3_2 = self.relu_layer(self.conv3_2, "relu3_2")
        self.conv3_3 = self.conv_layer(self.relu3_2, "conv3_3")
        self.relu3_3 = self.relu_layer(self.conv3_3, "relu3_3")
        self.conv3_4 = self.conv_layer(self.relu3_3, "conv3_4")
        self.relu3_4 = self.relu_layer(self.conv3_4, "relu3_4")
        self.pool3 = self.max_pool(self.relu3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.relu4_1 = self.relu_layer(self.conv4_1, "relu4_1")
        self.conv4_2 = self.conv_layer(self.relu4_1, "conv4_2")
        self.relu4_2 = self.relu_layer(self.conv4_2, "relu4_2")
        self.conv4_3 = self.conv_layer(self.relu4_2, "conv4_3")
        self.relu4_3 = self.relu_layer(self.conv4_3, "relu4_3")
        self.conv4_4 = self.conv_layer(self.relu4_3, "conv4_4")
        self.relu4_4 = self.relu_layer(self.conv4_4, "relu4_4")
        self.pool4 = self.max_pool(self.relu4_4, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.relu5_1 = self.relu_layer(self.conv5_1, "relu5_1")
        self.conv5_2 = self.conv_layer(self.relu5_1, "conv5_2")
        self.relu5_2 = self.relu_layer(self.conv5_2, "relu5_2")
        self.conv5_3 = self.conv_layer(self.relu5_2, "conv5_3")
        self.relu5_3 = self.relu_layer(self.conv5_3, "relu5_3")
        self.conv5_4 = self.conv_layer(self.relu5_3, "conv5_4")
        self.relu5_4 = self.relu_layer(self.conv5_4, "relu5_4")
        self.pool5 = self.max_pool(self.relu5_4, 'pool5')

        self.data_dict = None
        print(("build vgg19 model finished: %ds" % (time.time() - start_time)))

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
    
    def relu_layer(self, bottom, name):
        return tf.nn.relu(bottom, name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)
                   
            return bias

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")
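
Later, in DeblurGAN.py, this class is used for the content (perceptual) loss: the sharp and generated images, both scaled to [-1, 1], are concatenated into one batch, passed through build() a single time, and the relu3_3 feature maps are then split back by batch size. A minimal sketch of that pattern (the tensor names here are mine):

import tensorflow as tf
from vgg19 import Vgg19

batch_size = 1
sharp = tf.placeholder(tf.float32, [batch_size, 256, 256, 3])  # values in [-1, 1]
fake = tf.placeholder(tf.float32, [batch_size, 256, 256, 3])   # generator output

vgg = Vgg19('./vgg19.npy')
vgg.build(tf.concat([sharp, fake], axis=0))        # one VGG graph for both halves

feat_sharp = vgg.relu3_3[:batch_size]              # features of the sharp images
feat_fake = vgg.relu3_3[batch_size:]               # features of the generated images
content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(feat_fake - feat_sharp), axis=3))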

5. The image-processing file util.py

Since the data come in (blurred, sharp) pairs, util.py loads and batches them as pairs. The code:

from PIL import Image
import numpy as np
import random
import os

def image_loader(image_path, load_x, load_y, is_train = True):
    
    imgs = sorted(os.listdir(image_path))
    img_list = []
    for ele in imgs:
        img = Image.open(os.path.join(image_path, ele))
        if is_train:
            img = img.resize((load_x, load_y), Image.BICUBIC)
        img_list.append(np.array(img))
    
    return img_list

def data_augument(lr_img, hr_img, aug):
    
    if aug < 4:
        lr_img = np.rot90(lr_img, aug)
        hr_img = np.rot90(hr_img, aug)
    
    elif aug == 4:
        lr_img = np.fliplr(lr_img)
        hr_img = np.fliplr(hr_img)
        
    elif aug == 5:
        lr_img = np.flipud(lr_img)
        hr_img = np.flipud(hr_img)
        
    elif aug == 6:
        lr_img = np.rot90(np.fliplr(lr_img))
        hr_img = np.rot90(np.fliplr(hr_img))
        
    elif aug == 7:
        lr_img = np.rot90(np.flipud(lr_img))
        hr_img = np.rot90(np.flipud(hr_img))
        
    return lr_img, hr_img

def batch_gen(blur_imgs, sharp_imgs, patch_size, batch_size, random_index, step, augment=False):
    
    img_index = random_index[step * batch_size: (step + 1) * batch_size]
    
    all_img_blur = []
    all_img_sharp = []
    
    for _index in img_index:
        all_img_blur.append(blur_imgs[_index])
        all_img_sharp.append(sharp_imgs[_index])
    
    blur_batch = []
    sharp_batch = []
    
    for i in range(len(all_img_blur)):
        
        ih, iw, _ = all_img_blur[i].shape
        ix = random.randrange(0, iw - patch_size + 1)
        iy = random.randrange(0, ih - patch_size + 1)
        
        img_blur_in = all_img_blur[i][iy:iy + patch_size, ix:ix + patch_size]
        img_sharp_in = all_img_sharp[i][iy:iy + patch_size, ix:ix + patch_size]        
        
        if augment:
            
            aug = random.randrange(0, 8)
            img_blur_in, img_sharp_in = data_augument(img_blur_in, img_sharp_in, aug)

        blur_batch.append(img_blur_in)
        sharp_batch.append(img_sharp_in)
        
    blur_batch = np.array(blur_batch)
    sharp_batch = np.array(sharp_batch)
    
    return blur_batch, sharp_batch

6. The layer file layer.py

The convolution, transposed convolution, and normalization layers used by DeblurGAN are defined in layer.py:

import tensorflow as tf
import numpy as np


def Conv(name, x, filter_size, in_filters, out_filters, strides, padding):
    with tf.variable_scope(name):
        kernel = tf.get_variable('filter', [filter_size, filter_size, in_filters, out_filters], tf.float32,
                                 initializer=tf.random_normal_initializer(stddev=0.01))
        bias = tf.get_variable('bias', [out_filters], tf.float32, initializer=tf.zeros_initializer())
        
        return tf.nn.conv2d(x, kernel, [1, strides, strides, 1], padding=padding) + bias
    

def Conv_transpose(name, x, filter_size, in_filters, out_filters, fraction=2, padding="SAME"):
    with tf.variable_scope(name):
        n = filter_size * filter_size * out_filters
        kernel = tf.get_variable('filter', [filter_size, filter_size, out_filters, in_filters], tf.float32,
                                 initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/n)))
        size = tf.shape(x)
        output_shape = tf.stack([size[0], size[1] * fraction, size[2] * fraction, out_filters])
        x = tf.nn.conv2d_transpose(x, kernel, output_shape, [1, fraction, fraction, 1], padding)
        
        return x


def instance_norm(x, BN_epsilon=1e-3):
    # normalize each feature map over its spatial dimensions, per sample;
    # keep_dims=True is required so the statistics broadcast back against x
    mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    x = (x - mean) / ((variance + BN_epsilon) ** 0.5)
    return x
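
As a quick sanity check of these layers (my own snippet, not part of the original project), a Conv followed by instance_norm keeps the spatial size and normalizes each feature map per sample:

import numpy as np
import tensorflow as tf
from layer import Conv, instance_norm

x = tf.placeholder(tf.float32, [None, 64, 64, 3])
y = instance_norm(Conv('check_conv', x, filter_size=3, in_filters=3,
                       out_filters=64, strides=1, padding='SAME'))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.rand(2, 64, 64, 3)})
    print(out.shape)                # (2, 64, 64, 64)
    print(out.mean(), out.std())    # roughly 0 and 1, since each map is normalized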

7. The model file DeblurGAN.py

The files so far were all preparation; this is where the DeblurGAN model itself is built:

import tensorflow as tf

from layer import *
from data_loader import dataloader
from vgg19 import Vgg19


class DeblurGAN():
    
    def __init__(self, args):
        
        self.data_loader = dataloader(args)
        print("data has been loaded")

        self.channel = 3

        self.n_feats = args.n_feats
        self.mode = args.mode
        self.batch_size = args.batch_size      
        self.num_of_down_scale = args.num_of_down_scale
        self.gen_resblocks = args.gen_resblocks
        self.discrim_blocks = args.discrim_blocks
        self.vgg_path = args.vgg_path
        
        self.learning_rate = args.learning_rate
        self.decay_step = args.decay_step
        
    def down_scaling_feature(self, name, x, n_feats):
        x = Conv(name=name + 'conv', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats * 2, strides=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        
        return x
    
    def up_scaling_feature(self, name, x, n_feats):
        x = Conv_transpose(name=name + 'deconv', x=x, filter_size=3, in_filters=n_feats,
                           out_filters=n_feats // 2, fraction=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        
        return x
    
    def res_block(self, name, x, n_feats):
        
        _res = x
        
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv1', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv2', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)
        
        x = x + _res
        
        return x
    
    def generator(self, x, reuse=False, name='generator'):
        
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            _res = x
            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv1', x=x, filter_size=7, in_filters=self.channel,
                     out_filters=self.n_feats, strides=1, padding='VALID')
            # x = instance_norm(name = 'inst_norm1', x = x, dim = self.n_feats)
            x = instance_norm(x)
            x = tf.nn.relu(x)
            
            for i in range(self.num_of_down_scale):
                x = self.down_scaling_feature(name='down_%02d' % i, x=x, n_feats=self.n_feats * (i + 1))

            for i in range(self.gen_resblocks):
                x = self.res_block(name='res_%02d' % i, x=x, n_feats=self.n_feats * (2 ** self.num_of_down_scale))

            for i in range(self.num_of_down_scale):
                x = self.up_scaling_feature(name='up_%02d' % i, x=x,
                                            n_feats=self.n_feats * (2 ** (self.num_of_down_scale - i)))

            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv_last', x=x, filter_size=7, in_filters=self.n_feats,
                     out_filters=self.channel, strides=1, padding='VALID')
            x = tf.nn.tanh(x)
            x = x + _res
            x = tf.clip_by_value(x, -1.0, 1.0)
            
            return x
    
    def discriminator(self, x, reuse=False, name='discriminator'):
        
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            x = Conv(name='conv1', x=x, filter_size=4, in_filters=self.channel,
                     out_filters=self.n_feats, strides=2, padding="SAME")
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)
            
            n = 1
            
            for i in range(self.discrim_blocks):
                prev = n
                n = min(2 ** (i+1), 8)
                x = Conv(name='conv%02d' % i, x=x, filter_size=4, in_filters=self.n_feats * prev,
                         out_filters=self.n_feats * n, strides=2, padding="SAME")
                x = instance_norm(x)
                x = tf.nn.leaky_relu(x)
                
            prev = n
            n = min(2 ** self.discrim_blocks, 8)
            x = Conv(name='conv_d1', x=x, filter_size=4, in_filters=self.n_feats * prev,
                     out_filters=self.n_feats * n, strides=1, padding="SAME")
            # x = instance_norm(name = 'instance_norm_d1', x = x, dim = self.n_feats * n)
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)
            
            x = Conv(name='conv_d2', x=x, filter_size=4, in_filters=self.n_feats * n,
                     out_filters=1, strides=1, padding="SAME")
            x = tf.nn.sigmoid(x)
            
            return x
    
        
    def build_graph(self):
        
        # if self.in_memory:
        self.blur = tf.placeholder(name="blur", shape=[None, None, None, self.channel], dtype=tf.float32)
        self.sharp = tf.placeholder(name="sharp", shape=[None, None, None, self.channel], dtype=tf.float32)
            
        x = self.blur
        label = self.sharp
        
        self.epoch = tf.placeholder(name='train_step', shape=None, dtype=tf.int32)
        
        x = (2.0 * x / 255.0) - 1.0
        label = (2.0 * label / 255.0) - 1.0
        
        self.gene_img = self.generator(x, reuse=False)
        self.real_prob = self.discriminator(label, reuse=False)
        self.fake_prob = self.discriminator(self.gene_img, reuse=True)
        
        epsilon = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
        
        interpolated_input = epsilon * label + (1 - epsilon) * self.gene_img
        gradient = tf.gradients(self.discriminator(interpolated_input, reuse=True), [interpolated_input])[0]
        GP_loss = tf.reduce_mean(tf.square(tf.sqrt(tf.reduce_mean(tf.square(gradient), axis=[1, 2, 3])) - 1))
        
        d_loss_real = - tf.reduce_mean(self.real_prob)
        d_loss_fake = tf.reduce_mean(self.fake_prob)
        
        self.vgg_net = Vgg19(self.vgg_path)
        self.vgg_net.build(tf.concat([label, self.gene_img], axis=0))
        self.content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(
            self.vgg_net.relu3_3[self.batch_size:] - self.vgg_net.relu3_3[:self.batch_size]), axis=3))
        
        self.D_loss = d_loss_real + d_loss_fake + 10.0 * GP_loss
        self.G_loss = - d_loss_fake + 100.0 * self.content_loss
        
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]
        
        lr = tf.minimum(self.learning_rate, tf.abs(2 * self.learning_rate - (
                self.learning_rate * tf.cast(self.epoch, tf.float32) / self.decay_step)))
        self.D_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.D_loss, var_list=D_vars)
        self.G_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.G_loss, var_list=G_vars)
        
        self.PSNR = tf.reduce_mean(tf.image.psnr(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))
        self.ssim = tf.reduce_mean(tf.image.ssim(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))
        
        logging_D_loss = tf.summary.scalar(name='D_loss', tensor=self.D_loss)
        logging_G_loss = tf.summary.scalar(name='G_loss', tensor=self.G_loss)
        logging_PSNR = tf.summary.scalar(name='PSNR', tensor=self.PSNR)
        logging_ssim = tf.summary.scalar(name='ssim', tensor=self.ssim)

        self.output = (self.gene_img + 1.0) * 255.0 / 2.0
        self.output = tf.round(self.output)
        self.output = tf.cast(self.output, tf.uint8)
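
One detail in build_graph worth spelling out is the learning-rate schedule: lr = min(learning_rate, |2 * learning_rate - learning_rate * epoch / decay_step|) keeps the rate constant for the first decay_step epochs and then decays it linearly towards zero at 2 * decay_step epochs, which matches the schedule in the paper (constant for 150 epochs, then a linear decay over the next 150). With the defaults in main.py (decay_step = 150, max_epoch = 200), training stops well before the rate reaches zero.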

8. The training/testing file mode.py

mode.py contains the train and test functions (it could in fact be merged into main.py). The code first:

import os
import tensorflow as tf
from PIL import Image
import numpy as np
import time
import util


def train(args, model, sess, saver):
    
    if args.fine_tuning:
        saver.restore(sess, args.pre_trained_model)
        print("saved model is loaded for fine-tuning!")
        print("model path is %s" % args.pre_trained_model)
        
    num_imgs = len(os.listdir(args.train_Sharp_path))
    
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('./logs', sess.graph)
    if args.test_with_train:
        f = open("valid_logs.txt", 'w')
    
    epoch = 0
    step = num_imgs // args.batch_size
           
    blur_imgs = util.image_loader(args.train_Blur_path, args.load_X, args.load_Y)
    sharp_imgs = util.image_loader(args.train_Sharp_path, args.load_X, args.load_Y)
        
    while epoch < args.max_epoch:
        random_index = np.random.permutation(len(blur_imgs))
        for k in range(step):
            s_time = time.time()
            blur_batch, sharp_batch = util.batch_gen(blur_imgs, sharp_imgs, args.patch_size,
                                                     args.batch_size, random_index, k)
                
            for t in range(args.critic_updates):
                _, D_loss = sess.run([model.D_train, model.D_loss],
                                     feed_dict={model.blur: blur_batch, model.sharp: sharp_batch, model.epoch: epoch})
                    
            _, G_loss = sess.run([model.G_train, model.G_loss],
                                 feed_dict={model.blur: blur_batch, model.sharp: sharp_batch, model.epoch: epoch})
                             
            e_time = time.time()
            
        if epoch % args.log_freq == 0:
            summary = sess.run(merged, feed_dict={model.blur: blur_batch, model.sharp: sharp_batch})
            train_writer.add_summary(summary, epoch)
            if args.test_with_train:
                test(args, model, sess, saver, f, epoch, loading=False)
            print("%d training epoch completed" % epoch)
            print("D_loss : {}, \t G_loss : {}".format(D_loss, G_loss))
            print("Elpased time : %0.4f" % (e_time - s_time))
            # print("D_loss : %0.4f, \t G_loss : %0.4f" % (D_loss, G_loss))
            # print("Elpased time : %0.4f" % (e_time - s_time))
        if (epoch) % args.model_save_freq == 0:
            saver.save(sess, './model/DeblurrGAN', global_step=epoch, write_meta_graph=False)
            
        epoch += 1

    saver.save(sess, './model/DeblurrGAN_last', write_meta_graph=False)
    
    if args.test_with_train:
        f.close()
        
        
def test(args, model, sess, saver, file, step=-1, loading=False):
        
    if loading:

        import re
        print(" [*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(args.pre_trained_model)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(args.pre_trained_model, ckpt_name))
            print(" [*] Success to read {}".format(ckpt_name))
        else:
            print(" [*] Failed to find a checkpoint")
     
    blur_img_name = sorted(os.listdir(args.test_Blur_path))
    sharp_img_name = sorted(os.listdir(args.test_Sharp_path))
    
    PSNR_list = []
    ssim_list = []
        
    blur_imgs = util.image_loader(args.test_Blur_path, args.load_X, args.load_Y, is_train=False)
    sharp_imgs = util.image_loader(args.test_Sharp_path, args.load_X, args.load_Y, is_train=False)

    if not os.path.exists('./result/'):
        os.makedirs('./result/')

    for i, ele in enumerate(blur_imgs):
        blur = np.expand_dims(ele, axis = 0)
        sharp = np.expand_dims(sharp_imgs[i], axis = 0)
        output, psnr, ssim = sess.run([model.output, model.PSNR, model.ssim], feed_dict = {model.blur : blur, model.sharp : sharp})
        if args.save_test_result:
            output = Image.fromarray(output[0])
            split_name = blur_img_name[i].split('.')
            output.save(os.path.join(args.result_path, '%s_sharp.png'%(''.join(map(str, split_name[:-1])))))

        PSNR_list.append(psnr)
        ssim_list.append(ssim)
            
    length = len(PSNR_list)
    
    mean_PSNR = sum(PSNR_list) / length
    mean_ssim = sum(ssim_list) / length
    
    if step == -1:
        file.write('PSNR : {} SSIM : {}' .format(mean_PSNR, mean_ssim))
        file.close()
        
    else:
        file.write("{}d-epoch step PSNR : {} SSIM : {} \n".format(step, mean_PSNR, mean_ssim))

9. The main file main.py with the parameter settings

Finally, main.py sets the hyper-parameters and runs the model:

import tensorflow as tf
from DeblurGAN import DeblurGAN
from mode import *
import argparse

parser = argparse.ArgumentParser()


def str2bool(v):
    # note the trailing comma: ('true') without it is just a string, not a tuple
    return v.lower() in ('true',)


## Model specification
parser.add_argument("--n_feats", type=int, default=64)
parser.add_argument("--num_of_down_scale", type=int, default=2)
parser.add_argument("--gen_resblocks", type=int, default=9)
parser.add_argument("--discrim_blocks", type=int, default=3)

## Data specification 
parser.add_argument("--train_Sharp_path", type=str, default="./data/train/sharp/")
parser.add_argument("--train_Blur_path", type=str, default="./data/train/blur")
parser.add_argument("--test_Sharp_path", type=str, default="./data/val/val_sharp")
parser.add_argument("--test_Blur_path", type=str, default="./data/val/val_blur")
parser.add_argument("--vgg_path", type=str, default="./vgg19.npy")
parser.add_argument("--patch_size", type=int, default=256)
parser.add_argument("--result_path", type=str, default="./result")
parser.add_argument("--model_path", type=str, default="./model")

## Optimization
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_epoch", type=int, default=200)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--decay_step", type=int, default=150)
parser.add_argument("--test_with_train", type=str2bool, default=True)
parser.add_argument("--save_test_result", type=str2bool, default=True)

## Training or test specification
parser.add_argument("--mode", type=str, default="train")
parser.add_argument("--critic_updates", type=int, default=5)
parser.add_argument("--augmentation", type=str2bool, default=False)
parser.add_argument("--load_X", type=int, default=640)
parser.add_argument("--load_Y", type=int, default=360)
parser.add_argument("--fine_tuning", type=str2bool, default=False)
parser.add_argument("--log_freq", type=int, default=1)
parser.add_argument("--model_save_freq", type=int, default=20)
parser.add_argument("--pre_trained_model", type=str, default="./model/")
parser.add_argument("--test_batch", type=int, default=5)
args = parser.parse_args()

model = DeblurGAN(args)
model.build_graph()

print("Build DeblurGAN model!")

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=None)

if args.mode == 'train':
    train(args, model, sess, saver)
    
elif args.mode == 'test':
    f = open("test_results.txt", 'w')
    test(args, model, sess, saver, f, step=-1, loading=True)
    f.close()
    

IV. Experimental Results

With all the files in place, we can run the model. First, training: set the mode argument in main.py to train and start the training:

parser.add_argument("--mode", type=str, default="train")

Initially I set the number of epochs to 300 and saved the model every 50 epochs, but after a whole night of training on my GPU (GTX 1060, 3 GB) only 51 epochs had finished, so I reduced the epoch-related parameters and ended up testing with the model saved at epoch 50.

For testing, change the mode argument to test and run the code again:

parser.add_argument("--mode", type=str, default="test")

The results of the test run are shown below.

At a rough glance the results look quite good; zooming in, we can inspect how well the details are recovered.

Zoomed in, the output is clearly much sharper than the blurred input, but some regions remain blurry, perhaps because training was too short or because the originals are too severely blurred to restore.

V. Analysis

1. The file structure is given in Section III.

2. DeblurGAN did pioneering work in using a GAN for image deblurring.