
Adversarial Neural Network Learning (10): Removing Raindrops from Images with AttentiveGAN (a TensorFlow Implementation)

I. Background

AttentiveGAN is a model proposed by Rui Qian et al. in November 2017. Their paper, "Attentive Generative Adversarial Network for Raindrop Removal from A Single Image", introduces an attention map into the generator network, which noticeably improves the removal of raindrops from images.

This experiment is mainly based on the reference code [2], with some simple improvements, and implements the whole process in fairly compact code.

[1] Paper link: https://arxiv.org/pdf/1711.10098.pdf

[2] attentive-gan-derainnet: https://github.com/MaybeShewill-CV/attentive-gan-derainnet

II. How AttentiveGAN Works

There is not much written about AttentiveGAN online; one introduction worth recommending:

[3] Impressive results! A Peking University team proposes Attentive GAN for removing raindrops from images

Let us look at some of the authors' own descriptions. First, they state that removing raindrops from images is difficult for two reasons:

The problem is intractable, since first the regions occluded by raindrops are not given. Second, the information about the background scene of the occluded regions is completely lost for most part.

They then state that the paper's main contribution is injecting an attention map into the model, so that the generator can attend to the structure of the raindrop regions and the discriminator can assess the local consistency of the restored regions:

Our main idea is to inject visual attention into both the generative and discriminative networks. During the training, our visual attention learns about raindrop regions and their surroundings. Hence, by injecting this information, the generative network will pay more attention to the raindrop regions and the surrounding structures, and the discriminative network will be able to assess the local consistency of the restored regions. This injection of visual attention to both generative and discriminative networks is the main contribution of this paper.

So how is the attention map generated? The authors build it from ResNet blocks, an LSTM and a few convolutional layers, and name the structure the attentive-recurrent network. Meanwhile, the input image can be expressed as a combination of three parts:

I = (1 - M) \odot B + R

That is, the input image I can be seen as a blend of the background B, kept outside the raindrop mask M, with a raindrop effect R. The information we care about is mainly the background; the foreground, i.e. the raindrop regions, is usually blurred. Our goal is therefore to recover the background B from the input image I, and the raindrop mask M is exactly what the attention map is trained to estimate.
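
To make the decomposition concrete, here is a minimal NumPy sketch that composites a toy raindrop image from a background B, a binary mask M and a raindrop layer R; the shapes and random arrays are illustrative stand-ins, not real data:

import numpy as np

H, W = 240, 360
B = np.random.rand(H, W, 3).astype(np.float32)            # stand-in for a clean background photo
M = (np.random.rand(H, W, 1) > 0.95).astype(np.float32)   # binary raindrop mask, 1 inside raindrops
R = M * np.random.rand(H, W, 3).astype(np.float32)        # raindrop effect, nonzero only inside the mask

# I = (1 - M) * B + R: the background shows through outside the mask,
# while the raindrop effect takes over inside it
I = (1.0 - M) * B + R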

As for the network structure, the paper gives a schematic (not reproduced here):

As the figure shows, in the generator the authors first use a recurrent network to produce the attention maps; each generation step uses 5 ResNet blocks, an LSTM, and 1 convolutional layer. They then apply an autoencoder-like structure, the contextual autoencoder, which contains 16 conv-relu layers. For the discriminator, any inconsistency in an image is a convenient cue for judging real versus fake, so the authors adopt a local discriminator: the input is fed through a CNN to extract features, the attention map is injected to guide the discriminator, and fully connected layers are added at the end to output the real/fake decision. In total the discriminator contains 7 convolutional layers and one fully connected layer.

Finally, the authors briefly compare their model with other methods (comparison figures not reproduced here); the results are very good.

III. Introduction to All the Attentive GAN Files

Sections III, IV and V cover the implementation of the model. I mainly followed the reference code [2] and made only minor modifications. The paper's authors also released source code, but it is a PyTorch version; if you are interested, see [4]. Below, each file is introduced in turn.

[4] https://github.com/rui1996/DeRaindrop

1. File structure

The overall file structure is shown below; the files you need to prepare yourself are marked with #####:

-- attentive_gan_model                          # files for the attentiveGAN model itself
            |------ attentive_gan_net.py
            |------ cnn_basenet.py
            |------ derain_drop_net.py
            |------ discriminative_net.py
            |------ tf_ssim.py
            |------ vgg16.py
-- data_provider                                # data-loading code
            |------ data_provider.py
-- config                                       # configuration files
            |------ global_config.py
-- data2txt.py                                  # writes the data list into train.txt
-- train_model.py                               # training script
-- test_model.py                                # testing script
-- data
    |------ test_data                           # test data
                |------ 0_rain.png
                |------ 1_rain.png
                |------ ......
    |------ training_data                       # training data
                |------ data                    ##### training images with raindrops
                        |------ 0_rain.png
                        |------ 1_rain.png
                        |------ ......
                |------ gt                      ##### clean training images
                        |------ 0_clear.png
                        |------ 1_clear.png
                        |------ ......
                |------ train.txt               ##### generated by data2txt.py once the data is in place
    |------ vgg16.npy                           ##### download manually

2. Data preparation

Three things need to be prepared here:

(1) The vgg16.npy file

The reference code [2] does not come with the vgg16.npy file, so I found a copy myself and uploaded it to my own Baidu Cloud; of course, if you already have one you can use it directly. The link is given below:

Baidu Cloud link: https://pan.baidu.com/s/13lZ1PEVTvpBt5l1-7qZaPQ

Extraction code: hqnr

Once downloaded, place it under the path './data/'.

(2) Training and test data

If you don't want to train, you can simply pick a few photos of your own for testing. For the completeness of this little experiment, though, the paper's authors provide the original data for download; see:

https://drive.google.com/open?id=1e7R76s6vwUJxILOcAsthgDLPSnOrQ49K

The training set contains 861 image pairs and the test set 239 image pairs. The link above is not directly reachable from mainland China; once you can open it, just download the data.

Since downloading from the link above can be slow, I downloaded the data and put it on my own Baidu Cloud; if Google Drive is inconvenient for you, use the link below:

Baidu Cloud link: https://pan.baidu.com/s/1aXEr1Et10SDn5jRT-DISpg

Extraction code: een3

After downloading, unpack the data and place it into the corresponding folders following the file structure above.

(3) Creating the train.txt file

The final step is creating the txt file. The reference code does not include a script for this, so I wrote a rough one myself, the data2txt.py file, which is described in detail later.

To create the txt file, first put all the data in place, then simply run data2txt.py; this generates train.txt under './data/training_data/'. If every step was carried out correctly, each line of train.txt is one corresponding image pair, in the form shown below:
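
(The original screenshot is omitted. Judging from the data2txt.py code in subsection 6 below, each line holds the path of a rainy image and the path of its clean counterpart, separated by a space; the exact file names depend on your data, for example:)

./data/training_data/data/0_rain.png ./data/training_data/gt/0_clear.png
./data/training_data/data/1_rain.png ./data/training_data/gt/1_clear.png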

3. Files under the data_provider folder

The data_provider folder contains a single file, data_provider.py; its code is given below:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : data_provider.py
import os.path as ops

import numpy as np
import cv2

from config import global_config

CFG = global_config.cfg


class DataSet(object):

    def __init__(self, dataset_info_file):
        self._gt_img_list, self._gt_label_list = self._init_dataset(dataset_info_file)
        self._random_dataset()
        self._next_batch_loop_count = 0

    def _init_dataset(self, dataset_info_file):

        gt_img_list = []
        gt_label_list = []

        assert ops.exists(dataset_info_file), '{:s} does not exist'.format(dataset_info_file)

        with open(dataset_info_file, 'r') as file:
            for _info in file:
                info_tmp = _info.strip(' ').split()

                gt_img_list.append(info_tmp[0])
                gt_label_list.append(info_tmp[1])
                print(gt_img_list[-1], gt_label_list[-1])
        # print(gt_img_list, gt_label_list)
        return gt_img_list, gt_label_list

    def _random_dataset(self):

        assert len(self._gt_img_list) == len(self._gt_label_list)

        random_idx = np.random.permutation(len(self._gt_img_list))
        new_gt_img_list = []
        new_gt_label_list = []

        for index in random_idx:
            new_gt_img_list.append(self._gt_img_list[index])
            new_gt_label_list.append(self._gt_label_list[index])

        self._gt_img_list = new_gt_img_list
        self._gt_label_list = new_gt_label_list


    def next_batch(self, batch_size):
        assert len(self._gt_label_list) == len(self._gt_img_list)

        idx_start = batch_size * self._next_batch_loop_count
        idx_end = batch_size * self._next_batch_loop_count + batch_size

        if idx_end > len(self._gt_label_list):
            self._random_dataset()
            self._next_batch_loop_count = 0
            return self.next_batch(batch_size)
        else:
            gt_img_list = self._gt_img_list[idx_start:idx_end]
            gt_label_list = self._gt_label_list[idx_start:idx_end]

            gt_imgs = []
            gt_labels = []
            mask_labels = []

            for index, gt_img_path in enumerate(gt_img_list):
                gt_image = cv2.imread(gt_img_path, cv2.IMREAD_COLOR)
                label_image = cv2.imread(gt_label_list[index], cv2.IMREAD_COLOR)

                gt_image = cv2.resize(gt_image, (CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT))
                label_image = cv2.resize(label_image, (CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT))

                diff_image = np.abs(np.array(cv2.cvtColor(gt_image, cv2.COLOR_BGR2GRAY), np.float32) -
                                    np.array(cv2.cvtColor(label_image, cv2.COLOR_BGR2GRAY), np.float32))

                mask_image = np.zeros(diff_image.shape, np.float32)

                mask_image[np.where(diff_image >= 30)] = 1

                gt_image = np.divide(gt_image, 127.5) - 1
                label_image = np.divide(label_image, 127.5) - 1

                gt_imgs.append(gt_image)
                gt_labels.append(label_image)
                mask_labels.append(mask_image)

            self._next_batch_loop_count += 1
            return gt_imgs, gt_labels, mask_labels
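
As a quick sanity check of the provider, a minimal usage sketch (assuming train.txt already exists and the default 360x240 image size from global_config):

from data_provider import data_provider

dataset = data_provider.DataSet('./data/training_data/train.txt')
# next_batch returns three lists of length batch_size: images and labels scaled
# to [-1, 1], plus binary masks marking pixels where |gray(rain) - gray(clean)| >= 30
gt_imgs, gt_labels, mask_labels = dataset.next_batch(batch_size=2)
print(gt_imgs[0].shape, mask_labels[0].shape)  # (240, 360, 3) (240, 360)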

4. Files under the attentive_gan_model folder

There are quite a few files under attentive_gan_model. I read through them; the original code is well written, and I made almost no changes apart from deleting the code after if __name__ == '__main__'. The code of each file is given below:

The main code of attentive_gan_net.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : attentive_gan_net.py

import tensorflow as tf

from attentive_gan_model import cnn_basenet
from attentive_gan_model import vgg16


class GenerativeNet(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):

        super(GenerativeNet, self).__init__()
        self._vgg_extractor = vgg16.VGG16Encoder(phase='test')
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):

        return tf.equal(self._phase, self._train_phase)


    def _residual_block(self, input_tensor, name):

        output = None
        with tf.variable_scope(name):
            for i in range(5):
                if i == 0:
                    conv_1 = self.conv2d(inputdata=input_tensor,
                                         out_channel=32,
                                         kernel_size=3,
                                         padding='SAME',
                                         stride=1,
                                         use_bias=False,
                                         name='block_{:d}_conv_1'.format(i))
                    relu_1 = self.lrelu(inputdata=conv_1, name='block_{:d}_relu_1'.format(i))
                    output = relu_1
                    input_tensor = output
                else:
                    conv_1 = self.conv2d(inputdata=input_tensor,
                                         out_channel=32,
                                         kernel_size=1,
                                         padding='SAME',
                                         stride=1,
                                         use_bias=False,
                                         name='block_{:d}_conv_1'.format(i))
                    relu_1 = self.lrelu(inputdata=conv_1, name='block_{:d}_relu_1'.format(i))
                    conv_2 = self.conv2d(inputdata=relu_1,
                                         out_channel=32,
                                         kernel_size=1,
                                         padding='SAME',
                                         stride=1,
                                         use_bias=False,
                                         name='block_{:d}_conv_2'.format(i))
                    relu_2 = self.lrelu(inputdata=conv_2, name='block_{:d}_relu_2'.format(i))

                    output = self.lrelu(inputdata=tf.add(relu_2, input_tensor),
                                        name='block_{:d}_add'.format(i))
                    input_tensor = output

        return output

    def _conv_lstm(self, input_tensor, input_cell_state, name):

        with tf.variable_scope(name):
            conv_i = self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3, padding='SAME',
                                 stride=1, use_bias=False, name='conv_i')
            sigmoid_i = self.sigmoid(inputdata=conv_i, name='sigmoid_i')

            conv_f = self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3, padding='SAME',
                                 stride=1, use_bias=False, name='conv_f')
            sigmoid_f = self.sigmoid(inputdata=conv_f, name='sigmoid_f')

            cell_state = sigmoid_f * input_cell_state + \
                         sigmoid_i * tf.nn.tanh(self.conv2d(inputdata=input_tensor,
                                                            out_channel=32,
                                                            kernel_size=3,
                                                            padding='SAME',
                                                            stride=1,
                                                            use_bias=False,
                                                            name='conv_c'))
            conv_o = self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3, padding='SAME',
                                 stride=1, use_bias=False, name='conv_o')
            sigmoid_o = self.sigmoid(inputdata=conv_o, name='sigmoid_o')

            lstm_feats = sigmoid_o * tf.nn.tanh(cell_state)

            attention_map = self.conv2d(inputdata=lstm_feats, out_channel=1, kernel_size=3, padding='SAME',
                                        stride=1, use_bias=False, name='attention_map')
            attention_map = self.sigmoid(inputdata=attention_map)

            ret = {
                'attention_map': attention_map,
                'cell_state': cell_state,
                'lstm_feats': lstm_feats
            }

            return ret

    def build_attentive_rnn(self, input_tensor, name):

        [batch_size, tensor_h, tensor_w, _] = input_tensor.get_shape().as_list()
        with tf.variable_scope(name):
            init_attention_map = tf.constant(0.5, dtype=tf.float32,
                                             shape=[batch_size, tensor_h, tensor_w, 1])
            init_cell_state = tf.constant(0.0, dtype=tf.float32,
                                          shape=[batch_size, tensor_h, tensor_w, 32])
            init_lstm_feats = tf.constant(0.0, dtype=tf.float32,
                                          shape=[batch_size, tensor_h, tensor_w, 32])

            attention_map_list = []

            for i in range(4):
                attention_input = tf.concat((input_tensor, init_attention_map), axis=-1)
                conv_feats = self._residual_block(input_tensor=attention_input,
                                                  name='residual_block_{:d}'.format(i + 1))
                lstm_ret = self._conv_lstm(input_tensor=conv_feats,
                                           input_cell_state=init_cell_state,
                                           name='conv_lstm_block_{:d}'.format(i + 1))
                init_attention_map = lstm_ret['attention_map']
                init_cell_state = lstm_ret['cell_state']
                init_lstm_feats = lstm_ret['lstm_feats']

                attention_map_list.append(lstm_ret['attention_map'])

        ret = {
            'final_attention_map': init_attention_map,
            'final_lstm_feats': init_lstm_feats,
            'attention_map_list': attention_map_list
        }

        return ret

    def compute_attentive_rnn_loss(self, input_tensor, label_tensor, name):

        with tf.variable_scope(name):
            inference_ret = self.build_attentive_rnn(input_tensor=input_tensor,
                                                     name='attentive_inference')
            loss = tf.constant(0.0, tf.float32)
            n = len(inference_ret['attention_map_list'])
            for index, attention_map in enumerate(inference_ret['attention_map_list']):
                mse_loss = tf.pow(0.8, n - index + 1) * \
                           tf.losses.mean_squared_error(labels=label_tensor,
                                                        predictions=attention_map)
                loss = tf.add(loss, mse_loss)

        return loss, inference_ret['final_attention_map']

    def build_autoencoder(self, input_tensor, name):

        with tf.variable_scope(name):
            conv_1 = self.conv2d(inputdata=input_tensor, out_channel=64, kernel_size=5,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_1')
            relu_1 = self.lrelu(inputdata=conv_1, name='relu_1')

            conv_2 = self.conv2d(inputdata=relu_1, out_channel=128, kernel_size=3,
                                 padding='SAME',
                                 stride=2, use_bias=False, name='conv_2')
            relu_2 = self.lrelu(inputdata=conv_2, name='relu_2')

            conv_3 = self.conv2d(inputdata=relu_2, out_channel=128, kernel_size=3,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_3')
            relu_3 = self.lrelu(inputdata=conv_3, name='relu_3')

            conv_4 = self.conv2d(inputdata=relu_3, out_channel=128, kernel_size=3,
                                 padding='SAME',
                                 stride=2, use_bias=False, name='conv_4')
            relu_4 = self.lrelu(inputdata=conv_4, name='relu_4')

            conv_5 = self.conv2d(inputdata=relu_4, out_channel=256, kernel_size=3,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_5')
            relu_5 = self.lrelu(inputdata=conv_5, name='relu_5')

            conv_6 = self.conv2d(inputdata=relu_5, out_channel=256, kernel_size=3,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_6')
            relu_6 = self.lrelu(inputdata=conv_6, name='relu_6')

            dia_conv1 = self.dilation_conv(input_tensor=relu_6, k_size=3, out_dims=256, rate=2,
                                           padding='SAME', use_bias=False, name='dia_conv_1')
            relu_7 = self.lrelu(dia_conv1, name='relu_7')

            dia_conv2 = self.dilation_conv(input_tensor=relu_7, k_size=3, out_dims=256, rate=4,
                                           padding='SAME', use_bias=False, name='dia_conv_2')
            relu_8 = self.lrelu(dia_conv2, name='relu_8')

            dia_conv3 = self.dilation_conv(input_tensor=relu_8, k_size=3, out_dims=256, rate=8,
                                           padding='SAME', use_bias=False, name='dia_conv_3')
            relu_9 = self.lrelu(dia_conv3, name='relu_9')

            dia_conv4 = self.dilation_conv(input_tensor=relu_9, k_size=3, out_dims=256, rate=16,
                                           padding='SAME', use_bias=False, name='dia_conv_4')
            relu_10 = self.lrelu(dia_conv4, name='relu_10')

            conv_7 = self.conv2d(inputdata=relu_10, out_channel=256, kernel_size=3,
                                 padding='SAME', stride=1, use_bias=False,
                                 name='conv_7')
            relu_11 = self.lrelu(inputdata=conv_7, name='relu_11')

            conv_8 = self.conv2d(inputdata=relu_11, out_channel=256, kernel_size=3,
                                 padding='SAME', stride=1, use_bias=False,
                                 name='conv_8')
            relu_12 = self.lrelu(inputdata=conv_8, name='relu_12')

            deconv_1 = self.deconv2d(inputdata=relu_12, out_channel=128, kernel_size=4,
                                     stride=2, padding='SAME', use_bias=False, name='deconv_1')
            avg_pool_1 = self.avgpooling(inputdata=deconv_1, kernel_size=2, stride=1, padding='SAME',
                                         name='avg_pool_1')
            relu_13 = self.lrelu(inputdata=avg_pool_1, name='relu_13')

            conv_9 = self.conv2d(inputdata=tf.add(relu_13, relu_3), out_channel=128, kernel_size=3,
                                 padding='SAME', stride=1, use_bias=False,
                                 name='conv_9')
            relu_14 = self.lrelu(inputdata=conv_9, name='relu_14')

            deconv_2 = self.deconv2d(inputdata=relu_14, out_channel=64, kernel_size=4,
                                     stride=2, padding='SAME', use_bias=False, name='deconv_2')
            avg_pool_2 = self.avgpooling(inputdata=deconv_2, kernel_size=2, stride=1, padding='SAME',
                                         name='avg_pool_2')
            relu_15 = self.lrelu(inputdata=avg_pool_2, name='relu_15')

            conv_10 = self.conv2d(inputdata=tf.add(relu_15, relu_1), out_channel=32, kernel_size=3,
                                  padding='SAME', stride=1, use_bias=False,
                                  name='conv_10')
            relu_16 = self.lrelu(inputdata=conv_10, name='relu_16')

            skip_output_1 = self.conv2d(inputdata=relu_12, out_channel=3, kernel_size=3,
                                        padding='SAME', stride=1, use_bias=False,
                                        name='skip_ouput_1')

            skip_output_2 = self.conv2d(inputdata=relu_14, out_channel=3, kernel_size=3,
                                        padding='SAME', stride=1, use_bias=False,
                                        name='skip_output_2')

            skip_output_3 = self.conv2d(inputdata=relu_16, out_channel=3, kernel_size=3,
                                        padding='SAME', stride=1, use_bias=False,
                                        name='skip_output_3')

            # GAN generators conventionally use a tanh activation on the output layer
            skip_output_3 = tf.nn.tanh(skip_output_3, name='skip_output_3_tanh')

            ret = {
                'skip_1': skip_output_1,
                'skip_2': skip_output_2,
                'skip_3': skip_output_3
            }

        return ret

    def compute_autoencoder_loss(self, input_tensor, label_tensor, name):

        [_, ori_height, ori_width, _] = label_tensor.get_shape().as_list()
        label_tensor_ori = label_tensor
        label_tensor_resize_2 = tf.image.resize_bilinear(images=label_tensor,
                                                         size=(int(ori_height / 2), int(ori_width / 2)))
        label_tensor_resize_4 = tf.image.resize_bilinear(images=label_tensor,
                                                         size=(int(ori_height / 4), int(ori_width / 4)))
        label_list = [label_tensor_resize_4, label_tensor_resize_2, label_tensor_ori]
        lambda_i = [0.6, 0.8, 1.0]
        # compute the multi-scale loss lm_loss (see Eq. (5) in the paper)
        lm_loss = tf.constant(0.0, tf.float32)
        with tf.variable_scope(name):
            inference_ret = self.build_autoencoder(input_tensor=input_tensor, name='autoencoder_inference')
            output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]
            for index, output in enumerate(output_list):
                mse_loss = tf.losses.mean_squared_error(output, label_list[index]) * lambda_i[index]
                lm_loss = tf.add(lm_loss, mse_loss)

            # compute the perceptual loss lp_loss (see Eq. (6) in the paper)
            src_vgg_feats = self._vgg_extractor.extract_feats(input_tensor=label_tensor,
                                                              name='vgg_feats',
                                                              reuse=False)
            pred_vgg_feats = self._vgg_extractor.extract_feats(input_tensor=output_list[-1],
                                                               name='vgg_feats',
                                                               reuse=True)

            lp_losses = []
            for index, feats in enumerate(src_vgg_feats):
                lp_losses.append(tf.losses.mean_squared_error(src_vgg_feats[index], pred_vgg_feats[index]))
            lp_loss = tf.reduce_mean(lp_losses)

            loss = tf.add(lm_loss, lp_loss)

        return loss, inference_ret['skip_3']
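
As a reading aid, the losses implemented above can be written out; this is my own transcription of the code (N = 4 attention steps, and the exponent follows tf.pow(0.8, n - index + 1)), not a formula copied verbatim from the paper:

\mathcal{L}_{ATT} = \sum_{t=1}^{N} 0.8^{N-t+2} \, \mathrm{MSE}(A_t, M)

\mathcal{L}_{M} = \sum_{i=1}^{3} \lambda_i \, \mathrm{MSE}(S_i, T_i), \qquad \lambda = (0.6, 0.8, 1.0)

\mathcal{L}_{P} = \frac{1}{7} \sum_{k=1}^{7} \mathrm{MSE}\big(\mathrm{VGG}_k(T), \mathrm{VGG}_k(O)\big)

Here A_t are the attention maps, M the mask label, S_i the three skip outputs at 1/4, 1/2 and full resolution, T_i the correspondingly resized ground truth, O the final output (skip_3), and VGG_k the seven feature maps returned by the VGG16 encoder given further below.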

The main code of cnn_basenet.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://github.com/TJCVRS
#             @File    : cnn_basenet.py

import tensorflow as tf
import numpy as np


class CNNBaseModel(object):
    """
    Base model for other specific cnn ctpn_models
    """

    def __init__(self):
        pass

    @staticmethod
    def conv2d(inputdata, out_channel, kernel_size, padding='SAME',
               stride=1, w_init=None, b_init=None,
               split=1, use_bias=True, data_format='NHWC', name=None):
        """
        Packing the tensorflow conv2d function.
        :param name: op name
        :param inputdata: A 4D tensorflow tensor which must have a known number of channels, but can have other
        unknown dimensions.
        :param out_channel: number of output channel.
        :param kernel_size: int so only support square kernel convolution
        :param padding: 'VALID' or 'SAME'
        :param stride: int so only support square stride
        :param w_init: initializer for convolution weights
        :param b_init: initializer for bias
        :param split: split channels as used in Alexnet mainly group for GPU memory save.
        :param use_bias:  whether to use bias.
        :param data_format: default set to NHWC according to tensorflow
        :return: tf.Tensor named ``output``
        """
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'NHWC' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
            assert in_channel % split == 0
            assert out_channel % split == 0

            padding = padding.upper()

            if isinstance(kernel_size, list):
                # integer division so the filter shape stays an int under Python 3
                filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel // split, out_channel]
            else:
                filter_shape = [kernel_size, kernel_size] + [in_channel // split, out_channel]

            if isinstance(stride, list):
                strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                    else [1, 1, stride[0], stride[1]]
            else:
                strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                    else [1, 1, stride, stride]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_channel], initializer=b_init)

            if split == 1:
                conv = tf.nn.conv2d(inputdata, w, strides, padding, data_format=data_format)
            else:
                inputs = tf.split(inputdata, split, channel_axis)
                kernels = tf.split(w, split, 3)
                outputs = [tf.nn.conv2d(i, k, strides, padding, data_format=data_format)
                           for i, k in zip(inputs, kernels)]
                conv = tf.concat(outputs, channel_axis)

            ret = tf.identity(tf.nn.bias_add(conv, b, data_format=data_format)
                              if use_bias else conv, name=name)

        return ret

    @staticmethod
    def relu(inputdata, name=None):

        return tf.nn.relu(features=inputdata, name=name)

    @staticmethod
    def sigmoid(inputdata, name=None):
 
        return tf.nn.sigmoid(x=inputdata, name=name)

    @staticmethod
    def maxpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):

        padding = padding.upper()

        if stride is None:
            stride = kernel_size

        if isinstance(kernel_size, list):
            kernel = [1, kernel_size[0], kernel_size[1], 1] if data_format == 'NHWC' else \
                [1, 1, kernel_size[0], kernel_size[1]]
        else:
            kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
                else [1, 1, kernel_size, kernel_size]

        if isinstance(stride, list):
            strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                else [1, 1, stride[0], stride[1]]
        else:
            strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                else [1, 1, stride, stride]

        return tf.nn.max_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def avgpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):

        if stride is None:
            stride = kernel_size

        kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
            else [1, 1, kernel_size, kernel_size]

        strides = [1, stride, stride, 1] if data_format == 'NHWC' else [1, 1, stride, stride]

        return tf.nn.avg_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def globalavgpooling(inputdata, data_format='NHWC', name=None):

        assert inputdata.shape.ndims == 4
        assert data_format in ['NHWC', 'NCHW']

        axis = [1, 2] if data_format == 'NHWC' else [2, 3]

        return tf.reduce_mean(input_tensor=inputdata, axis=axis, name=name)

    @staticmethod
    def layernorm(inputdata, epsilon=1e-5, use_bias=True, use_scale=True,
                  data_format='NHWC', name=None):
        """
        :param name:
        :param inputdata:
        :param epsilon: epsilon to avoid divide-by-zero.
        :param use_bias: whether to use the extra affine transformation or not.
        :param use_scale: whether to use the extra affine transformation or not.
        :param data_format:
        :return:
        """
        shape = inputdata.get_shape().as_list()
        ndims = len(shape)
        assert ndims in [2, 4]

        mean, var = tf.nn.moments(inputdata, list(range(1, len(shape))), keep_dims=True)

        if data_format == 'NCHW':
            channnel = shape[1]
            new_shape = [1, channnel, 1, 1]
        else:
            channnel = shape[-1]
            new_shape = [1, 1, 1, channnel]
        if ndims == 2:
            new_shape = [1, channnel]

        if use_bias:
            beta = tf.get_variable('beta', [channnel], initializer=tf.constant_initializer())
            beta = tf.reshape(beta, new_shape)
        else:
            beta = tf.zeros([1] * ndims, name='beta')
        if use_scale:
            gamma = tf.get_variable('gamma', [channnel], initializer=tf.constant_initializer(1.0))
            gamma = tf.reshape(gamma, new_shape)
        else:
            gamma = tf.ones([1] * ndims, name='gamma')

        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def instancenorm(inputdata, epsilon=1e-5, data_format='NHWC', use_affine=True, name=None):
        shape = inputdata.get_shape().as_list()
        if len(shape) != 4:
            raise ValueError("Input data of instancebn layer has to be 4D tensor")

        if data_format == 'NHWC':
            axis = [1, 2]
            ch = shape[3]
            new_shape = [1, 1, 1, ch]
        else:
            axis = [2, 3]
            ch = shape[1]
            new_shape = [1, ch, 1, 1]
        if ch is None:
            raise ValueError("Input of instancebn require known channel!")

        mean, var = tf.nn.moments(inputdata, axis, keep_dims=True)

        if not use_affine:
            return tf.divide(inputdata - mean, tf.sqrt(var + epsilon), name='output')

        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
        gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0))
        gamma = tf.reshape(gamma, new_shape)
        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def dropout(inputdata, keep_prob, noise_shape=None, name=None):
        return tf.nn.dropout(inputdata, keep_prob=keep_prob, noise_shape=noise_shape, name=name)

    @staticmethod
    def fullyconnect(inputdata, out_dim, w_init=None, b_init=None,
                     use_bias=True, name=None):
        """
        Fully-Connected layer, takes a N>1D tensor and returns a 2D tensor.
        It is an equivalent of `tf.layers.dense` except for naming conventions.

        :param inputdata:  a tensor to be flattened except for the first dimension.
        :param out_dim: output dimension
        :param w_init: initializer for w. Defaults to `variance_scaling_initializer`.
        :param b_init: initializer for b. Defaults to zero
        :param use_bias: whether to use bias.
        :param name:
        :return: tf.Tensor: a NC tensor named ``output`` with attribute `variables`.
        """
        shape = inputdata.get_shape().as_list()[1:]
        if None not in shape:
            inputdata = tf.reshape(inputdata, [-1, int(np.prod(shape))])
        else:
            inputdata = tf.reshape(inputdata, tf.stack([tf.shape(inputdata)[0], -1]))

        if w_init is None:
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        ret = tf.layers.dense(inputs=inputdata, activation=lambda x: tf.identity(x, name='output'),
                              use_bias=use_bias, name=name,
                              kernel_initializer=w_init, bias_initializer=b_init,
                              trainable=True, units=out_dim)
        return ret

    @staticmethod
    def layerbn(inputdata, is_training, name):
        with tf.variable_scope(name):
            return tf.layers.batch_normalization(inputs=inputdata, training=is_training)

    @staticmethod
    def squeeze(inputdata, axis=None, name=None):
        return tf.squeeze(input=inputdata, axis=axis, name=name)

    @staticmethod
    def deconv2d(inputdata, out_channel, kernel_size, padding='SAME',
                 stride=1, w_init=None, b_init=None,
                 use_bias=True, activation=None, data_format='channels_last',
                 trainable=True, name=None):
        """
        Packing the tensorflow conv2d function.
        :param name: op name
        :param inputdata: A 4D tensorflow tensor which must have a known number of channels, but can have other
        unknown dimensions.
        :param out_channel: number of output channel.
        :param kernel_size: int so only support square kernel convolution
        :param padding: 'VALID' or 'SAME'
        :param stride: int so only support square stride
        :param w_init: initializer for convolution weights
        :param b_init: initializer for bias
        :param activation: whether to apply a activation func to deconv result
        :param use_bias:  whether to use bias.
        :param data_format: default set to NHWC according to tensorflow
        :return: tf.Tensor named ``output``
        """
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'channels_last' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            ret = tf.layers.conv2d_transpose(inputs=inputdata, filters=out_channel,
                                             kernel_size=kernel_size,
                                             strides=stride, padding=padding,
                                             data_format=data_format,
                                             activation=activation, use_bias=use_bias,
                                             kernel_initializer=w_init,
                                             bias_initializer=b_init, trainable=trainable,
                                             name=name)
        return ret

    @staticmethod
    def dilation_conv(input_tensor, k_size, out_dims, rate, padding='SAME',
                      w_init=None, b_init=None, use_bias=False, name=None):

        with tf.variable_scope(name):
            in_shape = input_tensor.get_shape().as_list()
            in_channel = in_shape[3]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if isinstance(k_size, list):
                filter_shape = [k_size[0], k_size[1]] + [in_channel, out_dims]
            else:
                filter_shape = [k_size, k_size] + [in_channel, out_dims]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_dims], initializer=b_init)

            conv = tf.nn.atrous_conv2d(value=input_tensor, filters=w, rate=rate,
                                       padding=padding, name='dilation_conv')

            if use_bias:
                ret = tf.add(conv, b)
            else:
                ret = conv

        return ret

    @staticmethod
    def spatial_dropout(input_tensor, keep_prob, is_training, name, seed=1234):
        tf.set_random_seed(seed=seed)

        def f1():
            with tf.variable_scope(name):
                return input_tensor

        def f2():
            with tf.variable_scope(name):
                num_feature_maps = [tf.shape(input_tensor)[0], tf.shape(input_tensor)[3]]

                random_tensor = keep_prob
                random_tensor += tf.random_uniform(num_feature_maps,
                                                   seed=seed,
                                                   dtype=input_tensor.dtype)

                binary_tensor = tf.floor(random_tensor)

                binary_tensor = tf.reshape(binary_tensor,
                                           [-1, 1, 1, tf.shape(input_tensor)[3]])
                ret = input_tensor * binary_tensor
                return ret

        output = tf.cond(is_training, f2, f1)
        return output

    @staticmethod
    def lrelu(inputdata, name, alpha=0.2):
        with tf.variable_scope(name):
            return tf.nn.relu(inputdata) - alpha * tf.nn.relu(-inputdata)
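
A minimal smoke test of the base model (assuming TensorFlow 1.x, where tf.placeholder and tf.contrib are still available):

import tensorflow as tf

from attentive_gan_model import cnn_basenet

model = cnn_basenet.CNNBaseModel()
x = tf.placeholder(tf.float32, shape=[1, 240, 360, 3], name='x')
conv = model.conv2d(inputdata=x, out_channel=32, kernel_size=3, use_bias=False, name='demo_conv')
act = model.lrelu(inputdata=conv, name='demo_lrelu')
print(act.get_shape().as_list())  # [1, 240, 360, 32]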

The main code of derain_drop_net.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : derain_drop_net.py

import tensorflow as tf

from attentive_gan_model import attentive_gan_net
from attentive_gan_model import discriminative_net


class DeRainNet(object):

    def __init__(self, phase):

        self._phase = phase
        self._attentive_gan = attentive_gan_net.GenerativeNet(self._phase)
        self._discriminator = discriminative_net.DiscriminativeNet(self._phase)

    def compute_loss(self, input_tensor, gt_label_tensor, mask_label_tensor, name):

        with tf.variable_scope(name):

            # compute the attentive-RNN loss
            attentive_rnn_loss, attentive_rnn_output = self._attentive_gan.compute_attentive_rnn_loss(
                input_tensor=input_tensor,
                label_tensor=mask_label_tensor,
                name='attentive_rnn_loss')

            auto_encoder_input = tf.concat((attentive_rnn_output, input_tensor), axis=-1)

            auto_encoder_loss, auto_encoder_output = self._attentive_gan.compute_autoencoder_loss(
                input_tensor=auto_encoder_input,
                label_tensor=gt_label_tensor,
                name='attentive_autoencoder_loss'
            )

            gan_loss = tf.add(attentive_rnn_loss, auto_encoder_loss)

            discriminative_inference, discriminative_loss = self._discriminator.compute_loss(
                input_tensor=auto_encoder_output,
                label_tensor=gt_label_tensor,
                attention_map=attentive_rnn_output,
                name='discriminative_loss')

            l_gan = tf.reduce_mean(tf.log(tf.subtract(tf.constant(1.0), discriminative_inference)) * 0.01)

            gan_loss = tf.add(gan_loss, l_gan)

            return gan_loss, discriminative_loss, auto_encoder_output

    # used at test time
    def build(self, input_tensor, name):

        with tf.variable_scope(name):

            attentive_rnn_out = self._attentive_gan.build_attentive_rnn(
                input_tensor=input_tensor,
                name='attentive_rnn_loss/attentive_inference'
            )

            attentive_autoencoder_input = tf.concat((attentive_rnn_out['final_attention_map'],
                                                     input_tensor), axis=-1)

            output = self._attentive_gan.build_autoencoder(
                input_tensor=attentive_autoencoder_input,
                name='attentive_autoencoder_loss/autoencoder_inference'
            )

            return output['skip_3'], attentive_rnn_out['attention_map_list']
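
Putting the pieces together, the generator objective minimized by compute_loss above is (again transcribed from the code):

\mathcal{L}_{G} = \mathcal{L}_{ATT} + \mathcal{L}_{M} + \mathcal{L}_{P} + 0.01 \, \mathbb{E}\big[\log(1 - D(O))\big]

where D(O) is the discriminator's sigmoid output for the restored image O; minimizing the last term pushes D(O) towards 1, i.e. the standard minimax adversarial objective for the generator.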

The main code of discriminative_net.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : discriminative_net.py

import tensorflow as tf

from attentive_gan_model import cnn_basenet


class DiscriminativeNet(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):

        super(DiscriminativeNet, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):

        return tf.equal(self._phase, self._train_phase)

    def _conv_stage(self, input_tensor, k_size, stride, out_dims, name):

        with tf.variable_scope(name):
            conv = self.conv2d(inputdata=input_tensor, out_channel=out_dims, kernel_size=k_size,
                               padding='SAME', stride=stride, use_bias=False, name='conv')

            relu = self.lrelu(conv, name='relu')

        return relu

    def build(self, input_tensor, name, reuse=False):

        with tf.variable_scope(name, reuse=reuse):
            conv_stage_1 = self._conv_stage(input_tensor=input_tensor, k_size=5,
                                            stride=1, out_dims=8,
                                            name='conv_stage_1')
            conv_stage_2 = self._conv_stage(input_tensor=conv_stage_1, k_size=5,
                                            stride=1, out_dims=16, name='conv_stage_2')
            conv_stage_3 = self._conv_stage(input_tensor=conv_stage_2, k_size=5,
                                            stride=1, out_dims=32, name='conv_stage_3')
            conv_stage_4 = self._conv_stage(input_tensor=conv_stage_3, k_size=5,
                                            stride=1, out_dims=64, name='conv_stage_4')
            conv_stage_5 = self._conv_stage(input_tensor=conv_stage_4, k_size=5,
                                            stride=1, out_dims=128, name='conv_stage_5')
            conv_stage_6 = self._conv_stage(input_tensor=conv_stage_5, k_size=5,
                                            stride=1, out_dims=128, name='conv_stage_6')
            attention_map = self.conv2d(inputdata=conv_stage_6, out_channel=1, kernel_size=5,
                                        padding='SAME', stride=1, use_bias=False, name='attention_map')
            conv_stage_7 = self._conv_stage(input_tensor=attention_map * conv_stage_6, k_size=5,
                                            stride=4, out_dims=64, name='conv_stage_7')
            conv_stage_8 = self._conv_stage(input_tensor=conv_stage_7, k_size=5,
                                            stride=4, out_dims=64, name='conv_stage_8')
            conv_stage_9 = self._conv_stage(input_tensor=conv_stage_8, k_size=5,
                                            stride=4, out_dims=32, name='conv_stage_9')
            fc_1 = self.fullyconnect(inputdata=conv_stage_9, out_dim=1024, use_bias=False, name='fc_1')
            fc_2 = self.fullyconnect(inputdata=fc_1, out_dim=1, use_bias=False, name='fc_2')
            fc_out = self.sigmoid(inputdata=fc_2, name='fc_out')

            fc_out = tf.where(tf.not_equal(fc_out, 1.0), fc_out, fc_out - 0.0000001)
            fc_out = tf.where(tf.not_equal(fc_out, 0.0), fc_out, fc_out + 0.0000001)

            return fc_out, attention_map, fc_2

    def compute_loss(self, input_tensor, label_tensor, attention_map, name):

        with tf.variable_scope(name):
            [batch_size, image_h, image_w, _] = input_tensor.get_shape().as_list()

            # the all-zeros map 0 from the paper (a real image should yield a zero attention map)
            zeros_mask = tf.zeros(shape=[batch_size, image_h, image_w, 1],
                                  dtype=tf.float32, name='O')
            fc_out_o, attention_mask_o, fc2_o = self.build(
                input_tensor=input_tensor, name='discriminative_inference')
            fc_out_r, attention_mask_r, fc2_r = self.build(
                input_tensor=label_tensor, name='discriminative_inference', reuse=True)

            l_map = tf.losses.mean_squared_error(attention_map, attention_mask_o) + \
                    tf.losses.mean_squared_error(attention_mask_r, zeros_mask)

            entropy_loss = -tf.log(fc_out_r) - tf.log(-tf.subtract(fc_out_o, tf.constant(1.0, tf.float32)))
            entropy_loss = tf.reduce_mean(entropy_loss)

            loss = entropy_loss + 0.05 * l_map

            return fc_out_o, loss
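
In the same notation, the discriminator loss implemented above is:

\mathcal{L}_{map} = \mathrm{MSE}\big(D_{map}(O), A\big) + \mathrm{MSE}\big(D_{map}(R), \mathbf{0}\big)

\mathcal{L}_{D} = \mathbb{E}\big[-\log D(R) - \log(1 - D(O))\big] + 0.05 \, \mathcal{L}_{map}

where R is the real (clean) image, O the generator output, A the generator's final attention map, and D_map the single-channel attention map the discriminator produces internally.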

The main code of tf_ssim.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : tf_ssim.py

import tensorflow as tf
import numpy as np


class SsimComputer(object):

    def __init__(self):
        pass

    @staticmethod
    def _tf_fspecial_gauss(size, sigma):

        x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]

        x_data = np.expand_dims(x_data, axis=-1)
        x_data = np.expand_dims(x_data, axis=-1)

        y_data = np.expand_dims(y_data, axis=-1)
        y_data = np.expand_dims(y_data, axis=-1)

        x = tf.constant(x_data, dtype=tf.float32)
        y = tf.constant(y_data, dtype=tf.float32)

        g = tf.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
        return g / tf.reduce_sum(g)

    def compute_ssim(self, img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5):

        assert img1.get_shape().as_list()[-1] == 1, 'Image must be gray scale'
        assert img2.get_shape().as_list()[-1] == 1, 'Image must be gray scale'

        window = self._tf_fspecial_gauss(size, sigma)  # window shape [size, size]
        K1 = 0.01  # origin parameter in paper
        K2 = 0.03  # origin parameter in paper
        L = 1  # depth of image (255 in case the image has a different scale)
        C1 = (K1 * L) ** 2
        C2 = (K2 * L) ** 2
        mu1 = tf.nn.conv2d(img1, window, strides=[1, 1, 1, 1], padding='VALID')
        mu2 = tf.nn.conv2d(img2, window, strides=[1, 1, 1, 1], padding='VALID')
        mu1_sq = mu1 * mu1
        mu2_sq = mu2 * mu2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = tf.nn.conv2d(img1 * img1, window, strides=[1, 1, 1, 1], padding='VALID') - mu1_sq
        sigma2_sq = tf.nn.conv2d(img2 * img2, window, strides=[1, 1, 1, 1], padding='VALID') - mu2_sq
        sigma12 = tf.nn.conv2d(img1 * img2, window, strides=[1, 1, 1, 1], padding='VALID') - mu1_mu2
        if cs_map:
            value = (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                  (sigma1_sq + sigma2_sq + C2)),
                     (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
        else:
            value = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                 (sigma1_sq + sigma2_sq + C2))

        if mean_metric:
            value = tf.reduce_mean(value)
        return value
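
compute_ssim evaluates the standard SSIM index under an 11x11 Gaussian window:

\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}

with C_1 = (K_1 L)^2, C_2 = (K_2 L)^2, K_1 = 0.01, K_2 = 0.03, and the dynamic range L fixed to 1 in the code.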

The main code of vgg16.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : vgg16.py

import tensorflow as tf

from attentive_gan_model import cnn_basenet


class VGG16Encoder(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):

        super(VGG16Encoder, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()
        print('VGG16 Network init complete')

    def _init_phase(self):

        return tf.equal(self._phase, self._train_phase)

    def _conv_stage(self, input_tensor, k_size, out_dims, name,
                    stride=1, pad='SAME', reuse=False):

        with tf.variable_scope(name, reuse=reuse):
            conv = self.conv2d(inputdata=input_tensor, out_channel=out_dims,
                               kernel_size=k_size, stride=stride,
                               use_bias=False, padding=pad, name='conv')
            relu = self.relu(inputdata=conv, name='relu')

            return relu

    def _fc_stage(self, input_tensor, out_dims, name, use_bias=False, reuse=False):

        with tf.variable_scope(name, reuse=reuse):
            fc = self.fullyconnect(inputdata=input_tensor, out_dim=out_dims, use_bias=use_bias,
                                   name='fc')
            relu = self.relu(inputdata=fc, name='relu')

        return relu

    def extract_feats(self, input_tensor, name, reuse=False):
        with tf.variable_scope(name, reuse=reuse):
            # conv stage 1_1
            conv_1_1 = self._conv_stage(input_tensor=input_tensor, k_size=3,
                                        out_dims=64, name='conv1_1')

            # conv stage 1_2
            conv_1_2 = self._conv_stage(input_tensor=conv_1_1, k_size=3,
                                        out_dims=64, name='conv1_2')

            # pool stage 1
            pool1 = self.maxpooling(inputdata=conv_1_2, kernel_size=2,
                                    stride=2, name='pool1')

            # conv stage 2_1
            conv_2_1 = self._conv_stage(input_tensor=pool1, k_size=3,
                                        out_dims=128, name='conv2_1')

            # conv stage 2_2
            conv_2_2 = self._conv_stage(input_tensor=conv_2_1, k_size=3,
                                        out_dims=128, name='conv2_2')

            # pool stage 2
            pool2 = self.maxpooling(inputdata=conv_2_2, kernel_size=2,
                                    stride=2, name='pool2')

            # conv stage 3_1
            conv_3_1 = self._conv_stage(input_tensor=pool2, k_size=3,
                                        out_dims=256, name='conv3_1')

            # conv_stage 3_2
            conv_3_2 = self._conv_stage(input_tensor=conv_3_1, k_size=3,
                                        out_dims=256, name='conv3_2')

            # conv stage 3_3
            conv_3_3 = self._conv_stage(input_tensor=conv_3_2, k_size=3,
                                        out_dims=256, name='conv3_3')

            ret = (conv_1_1, conv_1_2, conv_2_1, conv_2_2,
                   conv_3_1, conv_3_2, conv_3_3)

        return ret
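
Note that extract_feats only builds VGG16 up to conv3_3 and returns all seven intermediate feature maps; compute_autoencoder_loss in attentive_gan_net.py averages an MSE over exactly these maps to form the perceptual loss. Loading the pretrained vgg16.npy weights into these layers is not shown in the code excerpts above.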

5. Files under the config folder

The config folder contains a single file, global_config.py, which holds the network parameters. Note: if your GPU has relatively little memory, it is advisable to reduce the parameters. I use a GTX 1060 with 3 GB; running the code as-is runs out of memory, so consider shrinking the image size. The parameter settings are given below (I changed very little):

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : global_config.py
from easydict import EasyDict as edict

__C = edict()
# Consumers can get config by: from config import cfg

cfg = __C

# Train options
__C.TRAIN = edict()

__C.TRAIN.EPOCHS = 20010
__C.TRAIN.LEARNING_RATE = 0.002
# Set the GPU resource used during training process
__C.TRAIN.GPU_MEMORY_FRACTION = 0.95
# Set the GPU allow growth parameter during tensorflow training process
__C.TRAIN.TF_ALLOW_GROWTH = True
__C.TRAIN.BATCH_SIZE = 1
__C.TRAIN.IMG_HEIGHT = 240
__C.TRAIN.IMG_WIDTH = 360

# Test options
__C.TEST = edict()

# Set the GPU resource used during testing process
__C.TEST.GPU_MEMORY_FRACTION = 0.8
# Set the GPU allow growth parameter during tensorflow testing process
__C.TEST.TF_ALLOW_GROWTH = True
__C.TEST.BATCH_SIZE = 1
__C.TEST.IMG_HEIGHT = 240
__C.TEST.IMG_WIDTH = 360
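
Since cfg is a plain EasyDict, these values can also be overridden at run time before the graph is built; a minimal sketch for a small-memory GPU (the halved sizes are only an example):

from config import global_config

CFG = global_config.cfg
CFG.TRAIN.IMG_HEIGHT = 120  # half of the default 240
CFG.TRAIN.IMG_WIDTH = 180   # half of the default 360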

6. The data2txt.py file

I wrote this file myself to generate train.txt. The code:

import os


def data2txt(data_rootdir):
    # list the images in both folders and check that the counts match
    images = os.listdir(data_rootdir + 'data/')
    labels = os.listdir(data_rootdir + 'gt/')
    images.sort()
    labels.sort()

    image_len = len(images)
    label_len = len(labels)

    assert image_len == label_len

    # open the text file and write the path pairs
    trainText = open(data_rootdir + 'train.txt', 'w')
    for i in range(image_len):
        image_dir = data_rootdir + 'data/' + images[i] + ' '
        label_dir = data_rootdir + 'gt/' + labels[i] + '\n'

        trainText.write(image_dir)
        trainText.write(label_dir)

    trainText.close()
    print('finished!')


if __name__ == '__main__':
    data2txt('./data/training_data/')
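
One caveat: the pairing relies entirely on images.sort() and labels.sort() producing matching orders, so a rainy image and its ground truth must share the same sortable prefix (e.g. 0_rain.png and 0_clear.png). If the two folders follow different naming schemes, the pairs written to train.txt will be misaligned.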

7. The train_model.py file

train_model.py is the training script. Before training, the main thing is to check the parameter settings; once they are set, simply run the file to start training. Its code is given below:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : train_model.py

import os
import os.path as ops
import argparse
import time

import tensorflow as tf
import numpy as np
import glog as log

from data_provider import data_provider
from config import global_config
from attentive_gan_model import derain_drop_net
from attentive_gan_model import tf_ssim

CFG = global_config.cfg
VGG_MEAN = [103.939, 116.779, 123.68]


def init_args():

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, default='./data/training_data/', help='The dataset dir')
    parser.add_argument('--weights_path', type=str,
                        # default='',
                        help='The pretrained weights path', default=None)

    return parser.parse_args()


def train_model(dataset_dir, weights_path=None):

    # build the dataset
    with tf.device('/gpu:0'):
        train_dataset = data_provider.DataSet(ops.join(dataset_dir, 'train.txt'))

        # declare the input tensors
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        label_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
                                      name='label_tensor')
        mask_tensor = tf.placeholder(dtype=tf.float32,
                                     shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 1],
                                     name='mask_tensor')
        lr_tensor = tf.placeholder(dtype=tf.float32,
                                   shape=[],
                                   name='learning_rate')
        phase_tensor = tf.placeholder(dtype=tf.string, shape=[], name='phase')

        # declare the SSIM computer
        ssim_computer = tf_ssim.SsimComputer()

        # declare the derain network
        derain_net = derain_drop_net.DeRainNet(phase=phase_tensor)

        gan_loss, discriminative_loss, net_output = derain_net.compute_loss(
            input_tensor=input_tensor,
            gt_label_tensor=label_tensor,
            mask_label_tensor=mask_tensor,
            name='derain_net_loss')

        train_vars = tf.trainable_variables()

        ssim = ssim_computer.compute_ssim(tf.image.rgb_to_grayscale(net_output),
                                          tf.image.rgb_to_grayscale(label_tensor))

        d_vars = [tmp for tmp in train_vars if 'discriminative_loss' in tmp.name]
        g_vars = [tmp for tmp in train_vars if 'attentive_' in tmp.name and 'vgg_feats' not in tmp.name]
        vgg_vars = [tmp for tmp in train_vars if "vgg_feats" in tmp.name]

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(lr_tensor, global_step, 100000, 0.1, staircase=True)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            d_optim = tf.train.AdamOptimizer(learning_rate).minimize(discriminative_loss, var_list=d_vars)
            g_optim = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=tf.constant(0.9, tf.float32)).minimize(gan_loss, var_list=g_vars)

        # Set tf saver
        saver = tf.train.Saver()
        model_save_dir = './model/derain_gan_tensorflow'
        if not ops.exists(model_save_dir):
            os.makedirs(model_save_dir)
        train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        model_name = 'derain_gan_{:s}.ckpt'.format(str(train_start_time))
        model_save_path = ops.join(model_save_dir, model_name)

        # Set tf summary
        tboard_save_path = './tboard/derain_gan_tensorflow'
        if not ops.exists(tboard_save_path):
            os.makedirs(tboard_save_path)
        g_loss_scalar = tf.summary.scalar(name='gan_loss', tensor=gan_loss)
        d_loss_scalar = tf.summary.scalar(name='discriminative_loss', tensor=discriminative_loss)
        ssim_scalar = tf.summary.scalar(name='image_ssim', tensor=ssim)
        lr_scalar = tf.summary.scalar(name='learning_rate', tensor=lr_tensor)
        d_summary_op = tf.summary.merge([d_loss_scalar, lr_scalar])
        g_summary_op = tf.summary.merge([g_loss_scalar, ssim_scalar])

        # Set sess configuration
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION