1. 程式人生 > >Residual Dense Network for Image Super-Resolution 程式碼詳解

Residual Dense Network for Image Super-Resolution 程式碼詳解

Residual Dense Network for Image Super-Resolution

  • 以下是 RND論文Tensorflow版本實現的原始碼解析,我假設你已經瞭解Python的基本語法,和Tensorflow的基本用法,如果你對它們不是很熟悉,請到它們的官網查閱tutorial。

  • 以下所有程式碼你都可以在我的倉庫找到,chinese_annotation資料夾下是我新增中文註釋後的版本,在main.py裡面我做了一些修改,Feel free to tune the hyperparameters in it~

  • repo的readme.md說明了如何執行程式。

如果有不對的地方,還請大家指正!

Notice

使用TensorFlow搭建模型時,我們一般先將模型按照神經網路的結構搭建起來,這時TensorFlow只會建立好computation graph,實際的資料還需要等執行的時候feed in.

overview of the RDN model.

So,let’s get started from the model.py

這裡主要負責一些初始化工作:

  • sess用來傳遞一個TensorFlow會話(不懂也沒關係)
  • is_train和is_eval用來控制訓練還是測試,img_size是輸入圖片大小
  • c_dim是圖片通道數,用的是RGB圖所以c_dim=3
  • scale是超分辨放大的規模 x2或x3或x4,batch_size
  • batch_size就是batch_size了,哈哈
  • D是模型中Residual Dense Block塊的個數
  • C是每個Residual Dense Block塊中conv層數量
  • 模型中所有層輸出的feature maps不是 G G
    就是 G 0 G_0 ,詳細見論文
  • kernel_size是卷積核的大小
class RDN(object):

	def __init__(self,
				 sess,
				 is_train,
				 is_eval,
				 image_size,
				 c_dim,
				 scale,
				 batch_size,
				 D,
				 C,
				 G,
				 G0,
				 kernel_size ):

		self.sess = sess
		self.is_train = is_train
		self.is_eval = is_eval
		self.image_size = image_size
		self.c_dim = c_dim
		self.scale = scale
		self.batch_size = batch_size
		self.D = D
		self.C = C
		self.G = G
		self.G0 = G0
		self.kernel_size = kernel_size

Shallow Feature Extraction Net

淺層特徵提取部分,見網路的前兩個藍色部分塊,兩個conv層,產生F_-1和F_0,最後輸出有G個feature maps。

卷積核是一個四維的tensor -->(ks, ks, self.c_dim, G0)

  • 前兩個引數是卷積核kernel的size
  • 第三個是輸入tensor的通道數
  • 第四個是輸出tensor的通道數

偏置單元和輸出通道數保持一致

def SFEParams(self):
    """
    淺層特徵提取部分(兩個conv層,產生F_-1和F_0)
    最後輸出有G個feature maps
    :return:
    """
    G = self.G
    G0 = self.G0
    ks = self.kernel_size
    weightsS = {
        'w_S_1': tf.Variable(tf.random_normal([ks, ks, self.c_dim, G0], stddev=0.01), name='w_S_1'),
        'w_S_2': tf.Variable(tf.random_normal([ks, ks, G0, G], stddev=0.01), name='w_S_2')
    }
    biasesS = {
        'b_S_1': tf.Variable(tf.zeros([G0], name='b_S_1')),
        'b_S_2': tf.Variable(tf.zeros([G], name='b_S_2'))
    }

    return weightsS, biasesS

	

RDB Block

residual dense block,也就是網路中3個紅色塊部分,每個RDB中細節見下圖。

替代文字

第i個RDB塊接受第i-1個RDB塊傳來的輸出作為輸入,在每個RDB塊中,每一層的輸出都會送個它的後面所有層。第D個RDB塊的第c層輸出的公式如下:
$F_{d,c}=\sigma(W_{d,c}[F_{d-1},F_{d,1},F_{d,2}…F_{d,c-1}]) $

其中 [ F d 1 , F d , 1 , F d , 2 . . . F d , c 1 ] [F_{d-1},F_{d,1},F_{d,2}...F_{d,c-1}] 就是將它們concat在一起,也即包含 [ G 0 + ( c 1 ) G ] [G_0+(c-1)*G] 個feature maps。

每個RDB塊由以下模組裝成(conv1 -> relu1 -> conv2 -> relu2 … -> convC ->reluC -> concatnation -> 1*1 conv -> local residual)

def RDBParams(self):
    """
    RDB部分

    中間一個RDB塊(conv1 -> relu1 -> conv2 -> relu2 .... -> convC ->reluC
            -> concatnation -> 1*1 conv -> local residual)
    :return:
    """
    weightsR = {}
    biasesR = {}

    D = self.D
    C = self.C
    G = self.G
    G0 = self.G0
    ks = self.kernel_size

    for i in range(1, D + 1):
        for j in range(1, C + 1):
            # dense conv layers in i-th dense block
            weightsR.update({'w_R_%d_%d' % (i, j): tf.Variable(tf.random_normal([ks, ks, G * j, G], stddev=0.01),
                                                               name='w_R_%d_%d' % (i, j))})
            biasesR.update({'b_R_%d_%d' % (i, j): tf.Variable(tf.zeros([G], name='b_R_%d_%d' % (i, j)))})
        # local feature fusion in i-th dense block
        weightsR.update({'w_R_%d_%d' % (i, C + 1): tf.Variable(
            tf.random_normal([1, 1, G * (C + 1), G], stddev=0.01), name='w_R_%d_%d' % (i, C + 1))})
        biasesR.update({'b_R_%d_%d' % (i, C + 1): tf.Variable(tf.zeros([G], name='b_R_%d_%d' % (i, C + 1)))})

    return weightsR, biasesR

def RDBs(self, input_layer):
    rdb_concat = list()
    rdb_in = input_layer
    for i in range(1, self.D + 1):
        x = rdb_in
        for j in range(1, self.C + 1):
            tmp = tf.nn.conv2d(x, self.weightsR['w_R_%d_%d' % (i, j)], strides=[1, 1, 1, 1], padding='SAME') + \
                  self.biasesR['b_R_%d_%d' % (i, j)]
            tmp = tf.nn.relu(tmp)
            # 在最後一個維度做concat操作
            x = tf.concat([x, tmp], axis=3)

        # local feature fusion
        x = tf.nn.conv2d(x, self.weightsR['w_R_%d_%d' % (i, self.C + 1)], strides=[1, 1, 1, 1], padding='SAME') + \
            self.biasesR['b_R_%d_%d' % (i, self.C + 1)]
        # local residual learning
        rdb_in = tf.add(x, rdb_in)
        # 為global feature fusion做準備
        rdb_concat.append(rdb_in)
    # 在最後一個維度做concat
    return tf.concat(rdb_concat, axis=3)

Dense Feature Fusion

這一部分主要是將前面所有RDB的結果進行一個特徵融合,方法和RDB塊中最後的concat操作類似,就不再贅述了,參閱模型整體圖的三個紅色塊後面的concat操作,然後對concated tensor做 1 1 1*1 卷積到G個feature maps,再進行 3*3 卷積準備進行Global residual learning。公式如下:

F G F = H G F F ( [ F 1 , F 2 . . . F D ] ) F_{GF}=H_{GFF}([F_1,F_2... F_D])

def DFFParams(self):
    """
    dense feature fusion part
    :return:
    """
    D = self.D
    C = self.C
    G = self.G
    G0 = self.G0
    ks = self.kernel_size
    weightsD = {
        'w_D_1': tf.Variable(tf.random_normal([1, 1, G * D, G0], stddev=0.01), name='w_D_1'),
        'w_D_2': tf.Variable(tf.random_normal([ks, ks, G0, G0], stddev=0.01), name='w_D_2')
    }
    biasesD = {
        'b_D_1': tf.Variable(tf.zeros([G0], name='b_D_1')),
        'b_D_2': tf.Variable(tf.zeros([G0], name='b_D_2'))
    }

    return weightsD, biasesD

Upscale部分

這算是論文中的亮點之一了,RDN和以前的一些方法不一樣,以前許多模型都是先對低解析度的影象先進行upscale(如bicubic放大)到高解析度的影象,然後再輸入到神經網路進行計算;而RDN借鑑了ESPNN論文中提出的sub-pixel convolution方法,先將低解析度的影象輸入到神經網路進行計算,最後進行所謂的亞畫素卷積。

亞畫素卷積就是形如以下所示:

sub-pixel

本質上就是將低解析度特徵,按照特定位置,週期性的插入到高解析度影象中,可以通過顏色觀測到上圖的插入方式。

為了更好的理解,你可以這樣想象,假設最開始輸入的低解析度圖片是(Hight,Width,3)的向量,然後一系列操作之後,神經網路輸出的(Hight,Width,9) 的低分辨特徵,因為我們想放大3倍,所以最後一個維度就是放大的scale乘上想要輸出的channel,根據顏色你可以清楚的看到是怎麼將亞畫素進行“組裝”的。來個程式碼演示一下,你可以執行玩玩

# upsacale 測試程式碼
import numpy as np

a = np.ones(shape=[4, 4, 3 * 3])
for i in range(9):
	a[:, :, i] = a[:, :, i] * (i + 1)
print(a)
a = np.reshape(a, newshape=(4, 4, 3, 3))
print(a.shape)

# 將a分為4個,在第0個asix上切分
a = np.split(a, 4, 0) #  4,[1,4,3,3]
a = np.concatenate([np.squeeze(x) for x in a], 1) # [4,3*4,3]

# 重複一次以上操作
a = np.split(a, 4, 0) # 4,[1,3*4,3]
a = np.concatenate([np.squeeze(x) for x in a], 1) # [3*4,3*4]

print(a)
print(a.shape)
[[[1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]]

 [[1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]]

 [[1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]]

 [[1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]
  [1. 2. 3. 4. 5. 6. 7. 8. 9.]]]
(4, 4, 3, 3)
[[1. 2. 3. 1. 2. 3. 1. 2. 3. 1. 2. 3.]
 [4. 5. 6. 4. 5. 6. 4. 5. 6. 4. 5. 6.]
 [7. 8. 9. 7. 8. 9. 7. 8. 9. 7. 8. 9.]
 [1. 2. 3. 1. 2. 3. 1. 2. 3. 1. 2. 3.]
 [4. 5. 6. 4. 5. 6. 4. 5. 6. 4. 5. 6.]
 [7. 8. 9. 7. 8. 9. 7. 8. 9. 7. 8. 9.]
 [1. 2. 3. 1. 2. 3. 1. 2. 3. 1. 2. 3.]
 [4. 5. 6. 4. 5. 6. 4. 5. 6. 4. 5. 6.]
 [7. 8. 9. 7. 8. 9. 7. 8. 9. 7. 8. 9.]
 [1. 2. 3. 1. 2. 3. 1. 2. 3. 1. 2. 3.]
 [4. 5. 6. 4. 5. 6. 4. 5. 6. 4. 5. 6.]
 [7. 8. 9. 7. 8. 9. 7. 8. 9. 7. 8. 9.]]
(12, 12)
def UPNParams(self):
    # upscale part
    G0 = self.G0
    weightsU = {
        'w_U_1': tf.Variable(tf.random_normal([5, 5, G0, 64], stddev=0.01), name='w_U_1'),
        'w_U_2': tf.Variable(tf.random_normal([3, 3, 64, 32], stddev=0.01), name='w_U_2'),
        'w_U_3': tf.Variable(
            tf.random_normal([3, 3, 32, self.c_dim * self.scale * self.scale], stddev=np.sqrt(2.0 / 9 / 32)),
            name='w_U_3')
    }
    biasesU = {
        'b_U_1': tf.Variable(tf.zeros([64], name='b_U_1')),
        'b_U_2': tf.Variable(tf.zeros([32], name='b_U_2')),
        'b_U_3': tf.Variable(tf.zeros([self.c_dim * self.scale * self.scale], name='b_U_3'))
    }

    return weightsU, biasesU
    
def UPN(self, input_layer):
    # 輸出為 64 feature maps
    x = tf.nn.conv2d(input_layer, self.weightsU['w_U_1'], strides=[1, 1, 1, 1], padding='SAME') + self.biasesU[
        'b_U_1']
    x = tf.nn.relu(x)
    # 輸出為 32 feature maps
    x = tf.nn.conv2d(x, self.weightsU['w_U_2'], strides=[1, 1, 1, 1], padding='SAME') + self.biasesU['b_U_2']
    x = tf.nn.relu(x)
    # 輸出為 self.c_dim (3)* self.scale * self.scale 個 feature maps ,即低解析度特徵
    x = tf.nn.conv2d(x, self.weightsU['w_U_3'], strides=[1, 1, 1, 1], padding='SAME') + self.biasesU['b_U_3']
    # 將height和width放大
    x = self.PS(x, self.scale)

    return x

    
def PS(self, X, r):
    # Main OP that you can arbitrarily use in you tensorflow code
    # 在feature maps維上,分成3個Tensor,每個的shape應該是(batch_size,H,W, self.scale * self.scale)
    Xc = tf.split(X, 3, 3)
    if self.is_train:
        X = tf.concat([self._phase_shift(x, r) for x in Xc], 3)  # Do the concat RGB
    else:
        X = tf.concat([self._phase_shift_test(x, r) for x in Xc], 3)  # Do the concat RGB
    return X

# NOTE: train with batch size
def _phase_shift(self, I, r):
    """
    把最後一位放大的scale轉到Height和weight上
    :param I:
    :param r:放大因子
    :return:
    """
    # Helper function with main phase shift operation
    bsize, a, b, c = I.get_shape().as_list()
    X = tf.reshape(I, (self.batch_size, a, b, r, r))
    X = tf.split(X, a, 1)  # a, [bsize, b, r, r]
    X = tf.concat([tf.squeeze(x) for x in X], 2)  # bsize, b, a*r, r
    X = tf.split(X, b, 1)  # b, [bsize, a*r, r]
    X = tf.concat([tf.squeeze(x) for x in X], 2)  # bsize, a*r, b*r
    return tf.reshape(X, (self.batch_size, a * r, b * r, 1))

# NOTE: test without batchsize
def _phase_shift_test(self, I, r):
    bsize, a, b, c = I.get_shape().as_list()
    X = tf.reshape(I, (1, a, b, r, r))
    X = tf.split(X, a, 1)  # a, [bsize, b, r, r]
    X = tf.concat([tf.squeeze(x) for x in X], 1)  # bsize, b, a*r, r
    X = tf.split(X, b, 0)  # b, [bsize, a*r, r]
    X = tf.concat([tf.squeeze(x) for x in X], 1)  # bsize, a*r, b*r
    return tf.reshape(X, (1, a * r, b * r, 1))

Build all the blocks above together !

有了上面的模組,現在就可以將模型“堆起來“了!依次是四大部分(SFE,RDBs,DFF,UPN)

def build_model(self, images_shape, labels_shape):
    self.images = tf.placeholder(tf.float32, images_shape, name='images')
    # label是ground truth
    self.labels = tf.placeholder(tf.float32, labels_shape, name='labels')

    self.weightsS, self.biasesS = self.SFEParams()
    self.weightsR, self.biasesR = self.RDBParams()
    self.weightsD, self.biasesD = self.DFFParams()
    self.weightsU, self.biasesU = self.UPNParams()
    # 最後一個conv層,輸入是upscale後的RGB圖
    self.weight_final = tf.Variable(
        tf.random_normal([self.kernel_size, self.kernel_size, self.c_dim, self.c_dim], stddev=np.sqrt(2.0 / 9 / 3)),
        name='w_f')
    self.bias_final = tf.Variable(tf.zeros([self.c_dim], name='b_f')),

    self.pred = self.model()
    # MSE 均方誤差損失函式
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
    self.summary = tf.summary.scalar('loss', self.loss)
    self.saver = tf.train.Saver()


def model(self):
    # SFE部分
    F_1 = tf.nn.conv2d(self.images, self.weightsS['w_S_1'], strides=[1, 1, 1, 1], padding='SAME') + self.biasesS[
        'b_S_1']
    F0 = tf.nn.conv2d(F_1, self.weightsS['w_S_2'], strides=[1, 1, 1, 1], padding='SAME') + self.biasesS['b_S_2']

    # RDBs部分
    FD = self.RDBs(F0)
    
    # DFF部分,1*1卷積再3*3卷積
    FGF1 = tf.nn.conv2d(FD, self.weightsD['w_D_1'], strides=[1, 1, 1, 1], padding='SAME') + self.biasesD['b_D_1']
    FGF2 = tf.nn.conv2d(FGF1, self.weightsD['w_D_2'], strides=[1, 1, 1, 1], padding='SAME') + self.biasesD['b_D_2']
    
    # Global Residual Learning部分
    FDF = tf.add(FGF2, F_1)
    
    # UPscale部分
    FU = self.UPN(FDF)
    
    # 最後一個卷積操作後的到高解析度圖片
    IHR = tf.nn.conv2d(FU, self.weight_final, strides=[1, 1, 1, 1], padding='SAME') + self.bias_final

    return IHR

Train

下面會用到一些輔助函式,在utils.py檔案中

def train(self, config):
    print("\nPrepare Data...\n")
    # 儲存資料為.h5格式
    input_setup(config)
    data_dir = get_data_dir(config.checkpoint_dir, config.is_train)
    # 訓練樣本數
    data_num = get_data_num(data_dir)

    images_shape = [None, self.image_size, self.image_size, self.c_dim]
    labels_shape = [None, self.image_size * self.scale, self.image_size * self.scale, self.c_dim]
    self.build_model(images_shape, labels_shape)
    # adam 加速
    self.train_op = tf.train.AdamOptimizer(learning_rate=config.learning_rate).minimize(self.loss)
    tf.global_variables_initializer().run(session=self.sess)
    # merged_summary_op = tf.summary.merge_all()
    # 儲存計算圖到檔案(用於tensorboard視覺化)
    # summary_writer = tf.summary.FileWriter(config.checkpoint_dir, self.sess.graph)

    # 繼續模型之前的計算
    counter = self.load(config.checkpoint_dir)
    time_ = time.time()
    print("\nNow Start Training...\n")
    for ep in range(config.epoch):
        # Run by batch images
        # 根據batch_size將資料分組
        batch_idxs = data_num // config.batch_size
        for idx in range(0, batch_idxs):

            #with tf.device("/gpu:0"):
            batch_images, batch_labels = get_batch(data_dir, data_num, config.batch_size)
            # 每一個batch counter加一,也就是平常我們說的iteration
            counter += 1

            _, err = self.sess.run([self.train_op, self.loss],
                                   feed_dict={self.images: batch_images, self.labels: batch_labels})

            if counter % 10 == 0:
                print("Epoch: [%2d], batch: [%2d/%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" % (
                    (ep + 1), idx, batch_idxs, counter, time.time() - time_, err))

            # 每100個batch就儲存一次模型
            if counter % 100 == 0:
                self.save(config.checkpoint_dir, counter)

            # summary_str = self.sess.run(merged_summary_op)
            # summary_writer.add_summary(summary_str, counter)

            if counter > 0 and counter == batch_idxs * config.epoch:
                return

def load(self, checkpoint_dir):
    """
    從指定目錄載入模型已經計算的部分,並接著計算
    :param checkpoint_dir:
    :return:
    """
    print("\nReading Checkpoints.....\n")
    model_dir = "%s_%s_%s_%s_x%s" % ("rdn", self.D, self.C, self.G, self.scale)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
    """
    關於tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None):
        返回:checkpoint檔案CheckpointState proto型別的內容,
            其中有model_checkpoint_path和all_model_checkpoint_paths兩個屬性。

            model_checkpoint_path:儲存了最新的tensorflow模型檔案的檔名,
            all_model_checkpoint_paths:則有未被刪除的所有tensorflow模型檔案的檔名。
    """
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

    if ckpt and ckpt.model_checkpoint_path:
        ckpt_path = str(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(os.getcwd(), ckpt_path))
        step = int(os.path.basename(ckpt_path).split('-')[1])
        print("\nCheckpoint Loading Success! %s\n" % ckpt_path)
    else:
        step = 0
        print("\nCheckpoint Loading Failed! \n")

    return step

def save(self, checkpoint_dir, step):
    model_name = "RDN.model"
    model_dir = "%s_%s_%s_%s_x%s" % ("rdn", self.D, self.C, self.G, self.scale)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),
                    global_step=step)

下面是utils.py中的一些輔助函式

這部分程式碼有點多,大家如果看不明白可以暫時跳過。: )
71-77行程式碼我也沒看明白是做了一個什麼操作,如果有知道的,請指教,謝謝!

def input_setup(config):
	"""
		Read image files and make their sub-images and saved them as a h5 file format
	"""
	# data為所有圖片路徑組成的列表
	data = prepare_data(config)
	make_sub_data(data, config)
    

def prepare_data(config):
	"""
	根據config.isTrain屬性返回包含訓練集或測試集的圖片路徑
	:param config:
	:return: 所有圖片路徑組成的列表
	"""
	if config.is_train:
		data_dir = os.path.join(os.path.join(os.getcwd(), "Train"), config.train_set)
		# 獲取當前路徑下的所有png圖片
		data = glob.glob(os.path.join(data_dir, "*.png"))
	else:
		if config.test_img != "":
			data = [os.path.join(os.getcwd(), config.test_img)]
		else:
			data_dir = os.path.join(os.path.join(os.getcwd(), "Test"), config.test_set)
			data = glob.glob(os.path.join(data_dir, "*.bmp"))
	return data

def make_sub_data(data, config):
	"""
	取樣產生更多樣本資料
	:param data: 源資料路徑
	:param config:
	:return:
	"""
	# 是否使用MATLAB中的bicubic
	if config.matlab_bicubic:
		import matlab.engine
		eng = matlab.engine.start_matlab()
		mdouble = matlab.double
	else:
		eng = None
		mdouble = None

	times = 0
	for i in range(len(data)):
        # 對圖片進行預處理
		input_, label_, = preprocess(data[i], config.scale, eng, mdouble)
		if len(input_.shape) == 3:
			h, w, c = input_.shape
		else:
			h, w = input_.shape
		# 如果不是訓練過程
		if not config.is_train:
			input_ = input_ / 255.0
			label_ = label_ / 255.0
			make_data_hf(input_, label_, config, times)
			return data

		for x in range(0, h * config.scale - config.image_size * config.scale + 1, config.stride * config.scale):
			for y in range(0, w * config.scale - config.image_size * config.scale + 1, config.stride * config.scale):
				# 滑動視窗取樣資料(data augmentation)
				sub_label = label_[x: x + config.image_size * config.scale, y: y + config.image_size * config.scale]

				sub_label = sub_label.reshape(
					[config.image_size * config.scale, config.image_size * config.scale, config.c_dim])

				# 將取樣的ground truth RGB圖片轉到YCrCb顏色域下
				t = cv2.cvtColor(sub_label, cv2.COLOR_BGR2YCR_CB)
				
                # 這裡做了一個判斷,暫時沒搞明白
				t = t[:, :, 0]
				gx = t[1:, 0:-1] - t[0:-1, 0:-1]
				gy = t[0:-1, 1:] - t[0:-1, 0:-1]
				Gxy = (gx ** 2 + gy ** 2) ** 0.5
				r_gxy = float((Gxy > 10).sum()) / ((config.image_size * config.scale) ** 2) * 100
				if r_gxy < 10:
					continue

				sub_label = sub_label / 255.0
                
				# 取樣的ground truth RGB圖片對應的低解析度影象
				x_i = x // config.scale
				y_i = y // config.scale
				sub_input = input_[x_i: x_i + config.image_size, y_i: y_i + config.image_size]
				sub_input = sub_input.reshape([config.image_size, config.image_size, config.c_dim])
				sub_input = sub_input / 255.0

				# checkimage(sub_input)
				# checkimage(sub_label)

				# 將取樣的低解析度影象和ground truth影象儲存為.h5格式
				save_flag = make_data_hf(sub_input, sub_label, config, times)
				# 一旦儲存為.h5檔案失敗,就停止對資料集的取樣操作
				if not save_flag:
					return data
				times += 1

		print("image: [%2d], total: [%2d]" % (i, len(data)))

	if config.matlab_bicubic:
		eng.quit()

	return data


def preprocess(path, scale=3, eng=None, mdouble=None):
	"""
	對單張圖片預處理
	:param path: 圖片地址
	:param scale: 縮放規模
	:param eng: MATLAB呼叫引擎
	:param mdouble: MATLAB double
	:return: (處理後(縮小後)的圖片,ground truth的圖片)  tuple
	"""
	img = imread(path)
	# 裁剪,使得圖片的長寬可以整除scale
	label_ = modcrop(img, scale)
	# eng是MATLAB呼叫引擎,如果沒有安裝MATLAB的python支援庫,則呼叫cv2中的bicubic
	if eng is None:
		input_ = cv2.resize(label_, None, fx=1.0 / scale, fy=1.0 / scale, interpolation=cv2.INTER_CUBIC)
	else:
		input_ = np.asarray(eng.imresize(mdouble(label_.tolist()), 1.0 / scale, 'bicubic'))

	# 最後一維翻轉(因為OpenCV中的imread()讀取圖片的順序不是R、G、B三個次序,而是R、B、G)
	input_ = input_[:, :, ::-1]
	label_ = label_[:, :, ::-1]

	return input_, label_


def modcrop(img, scale=3):
	"""
	將原影象的長寬都改變成scale的引數,以便於取樣
	:param img:
	:param scale:
	:return:
	"""
	if len(img.shape) == 3:
		h, w, _ = img.shape
		h = (h // scale) * scale
		w = (w // scale) * scale
		img = img[0:h, 0:w, :]
	else:
		h, w = img.shape
		h = (h // scale) * scale
		w = (w // scale) * scale
		img = img[0:h, 0:w]
	return img


def make_data_hf(input_, label_, config, times):
	"""
	將低解析度圖片和ground truth圖片儲存為.h5格式
	hf means hfive  ooops.. :)
	:param input_:
	:param label_:
	:param config:
	:param times:
	:return: bool
	"""
	if not os.path.isdir(os.path.join(os.getcwd(), config.checkpoint_dir)):
		os.makedirs(os.path.join(os.getcwd(), config.checkpoint_dir))

	if config.is_train:
		savepath = os.path.join(os.path.join(os.getcwd(), config.checkpoint_dir), 'train.h5')
	else:
		savepath = os.path.join(os.path.join(os.getcwd(), config.checkpoint_dir), 'test.h5')
	# 第一次儲存到.h5時,以“w”模式開啟
	if times == 0:
		if os.path.exists(savepath):
			print("\n%s have existed!\n" % (savepath))
			return False
		else:
			hf = h5py.File(savepath, 'w')

			# 訓練
			if config.is_train:
				# chunck 分塊儲存
				input_h5 = hf.create_dataset("input", (1, config.image_size, config.image_size, config.c_dim),
											 maxshape=(None, config.image_size, config.image_size, config.c_dim),
											 chunks=(1, config.image_size, config.image_size, config.c_dim),
											 dtype='float32')
				label_h5 = hf.create_dataset("label", (1, config.image_size * config.scale, config.image_size * config.scale, config.c_dim),
											 maxshape=( None, config.image_size * config.scale,config.image_size * config.scale,config.c_dim),
											 chunks=(1, config.image_size * config.scale, config.image_size * config.scale,config.c_dim), dtype='float32')
			# 測試
			else:
				input_h5 = hf.create_dataset("input", (1, input_.shape[0], input_.shape[1], input_.shape[2]),
											 maxshape=(None, input_.shape[0], input_.shape[1], input_.shape[2]),
											 chunks=(1, input_.shape[0], input_.shape[1], input_.shape[2]),
											 dtype='float32')
				label_h5 = hf.create_dataset("label", (1, label_.shape[0], label_.shape[1], label_.shape[2]),
											 maxshape=(None, label_.shape[0], label_.shape[1], label_.shape[2]),
											 chunks=(1, label_.shape[0], label_.shape[1], label_.shape[2]),
											 dtype='float32')
	# 其它形式下,用“a”模式開啟
	else:
		hf = h5py.File(savepath, 'a')
		input_h5 = hf["input"]
		label_h5 = hf["label"]

	if config.is_train:
		input_h5.resize([times + 1, config.image_size, config.image_size, config.c_dim])
		input_h5[times: times + 1] = input_
		label_h5.resize([times + 1, config.image_size * config.scale, config.image_size * config.scale, config.c_dim])
		label_h5[times: times + 1] = label_
	else:
		input_h5.resize([times + 1, input_.shape[0], input_.shape[1], input_.shape[2]])
		input_h5[times: times + 1] = input_
		label_h5.resize([times + 1, label_.shape[0], label_.shape[1], label_.shape[2]])
		label_h5[times: times + 1] = label_

	hf.close()
	return True

def get_data_dir(checkpoint_dir, is_train):
	"""
	獲取資料集的目錄(訓練和測試模式)
	:param checkpoint_dir:
	:param is_train:
	:return: 對應的.h5檔案
	"""
	if is_train:
		return os.path.join(os.path.join(os.getcwd(), checkpoint_dir), 'train.h5')
	else:
		return os.path.join(os.path.join(os.getcwd(), checkpoint_dir), 'test.h5')


def get_data_num(path):
	"""
	獲取.h5檔案的input資料集中樣本個數
	:param path:
	:return:
	"""
	with h5py.File(path, 'r') as hf:
		input_ = hf['input']
		return input_.shape[0]
    
def get_batch(path, data_num, batch_size):
	"""
	獲取batch_size個樣本
	:param path: 資料集地址
	:param data_num: 資料集總數
	:param batch_size: batch大小
	:return:資料增強後的資料集合  (batch_size,H,W,3)
	"""
	with h5py.File(path, 'r') as hf:
		input_ = hf['input']
		label_ = hf['label']
		# batch size
		random_batch = np.random.rand(batch_size) * (data_num - 1)  # batch size 個樣本資料的下標
		batch_images = np.zeros([batch_size, input_[0].shape[0], input_[0].shape[1], input_[0].shape[2]])
		batch_labels = np.zeros([batch_size, label_[0].shape[0], label_[0].shape[1], label_[0].shape[2]])
		for i in range(batch_size):
			batch_images[i, :, :, :] = np.asarray(input_[int(random_batch[i])])
			batch_labels[i, :, :, :] = np.asarray(label_[int(random_batch[i])])
		# data augmentation
		random_aug = np.random.rand(2)
		# 翻轉或旋轉
		batch_images = augmentation(batch_images, random_aug)
		batch_labels = augmentation(batch_labels, random_aug)
		return batch_images, batch_labels
    
def augmentation(batch, random):
	if random[0] < 0.3:
		# 在batch的第shape[1]上,上下翻轉
		batch_flip = np.flip(batch, 1)
	elif random[0] > 0.7:
		# 在batch的第shape[2]上,左右翻轉
		batch_flip = np.flip(batch, 2)
	else:
		# 不翻轉
		batch_flip = batch

	# 在翻轉的基礎上旋轉
	if random[1] < 0.5:
		# 逆時針旋轉90度
		batch_rot = np.rot90(batch_flip, 1, [1, 2])
	else:
		batch_rot = batch_flip

Evaluation & Test

def eval(self, config):
    print("\nPrepare Data...\n")
    paths = prepare_data(config)
    data_num = len(paths)

    avg_time = 0
    avg_pasn = 0
    print("\nNow Start Testing...\n")
    for idx in range(data_num):
        input_, label_ = get_image(paths[idx], config.scale, config.matlab_bicubic)

        images_shape = input_.shape
        labels_shape = label_.shape
        self.build_model(images_shape, labels_shape)
        tf.global_variables_initializer().run(session=self.sess)

        self.load(config.checkpoint_dir)

        time_ = time.time()
        result = self.sess.run([self.pred], feed_dict={self.images: input_ / 255.0})
        avg_time += time.time() - time_

        # import matlab.engine
        # eng = matlab.engine.start_matlab()
        # time_ = time.time()
        # result = np.asarray(eng.imresize(matlab.double((input_[0, :] / 255.0).tolist()), config.scale, 'bicubic'))
        # avg_time += time.time() - time_

        self.sess.close()
        tf.reset_default_graph()
        self.sess = tf.Session()

        x = np.squeeze(result) * 255.0
        x = np.clip(x, 0, 255)
        psnr = PSNR(x, label_[0], config.scale)
        avg_pasn += psnr

        print("image: %d/%d, time: %.4f, psnr: %.4f" % (idx, data_num, time.time() - time_, psnr))

        if not os.path.isdir(os.path.join(os.getcwd(), config.result_dir)):
            os.makedirs(os.path.join(os.getcwd(), config.result_dir))
        imsave(x[:, :, ::-1], config.result_dir + '/%d.png' % idx)

    print("Avg. Time:", avg_time / data_num)
    print("Avg. PSNR:", avg_pasn / data_num)

def test(self, config):
    print("\nPrepare Data...\n")
    paths = prepare_data(config)
    data_num = len(paths)

    avg_time = 0
    print("\nNow Start Testing...\n")
    for idx in range(data_num):
        input_ = imread(paths[idx])
        input_ = input_[:, :, ::-1]
        input_ = input_[np.newaxis, :]

        images_shape = input_.shape
        labels_shape = input_.shape * np.asarray([1, self.scale, self.scale, 1])
        self.build_model(images_shape, labels_shape)
        tf.global_variables_initializer().run(session=self.sess)

        self.load(config.checkpoint_dir)

        time_ = time.time()
        result = self.sess.run([self.pred], feed_dict={self.images: input_ / 255.0})
        avg_time += time.time() - time_

        self.sess.close()
        tf.reset_default_graph()
        self.sess = tf.Session()

        x = np.squeeze(result) * 255.0
        x = np.clip(x, 0, 255)
        x = x[:, :, ::-1]
        checkimage(np.uint8(x))

        if not os.path.isdir(os.path.join(os.getcwd(), config.result_dir)):
            os.makedirs(os.path.join(os.getcwd(), config.result_dir))
        imsave(x, config.result_dir + '/%d.png' % idx)

    print("Avg. Time:", avg_time / data_num)
    
    
def rgb2ycbcr(img):
	"""
	將RGB圖轉化為YCbCr顏色格式的圖片

	:param img: RGB圖,(H,W,3)
	:return: (H,W)
	"""
	y = 16 + (65.481 * img[:, :, 0]) \
		+ (128.553 * img[:, :, 1]) \
		+ (24.966 * img[:, :, 2])
	return y / 255


def PSNR(target, ref, scale):
	"""
	影象質量指標函式,PSNR指標
	PSNR=-10*LOG(MSE/255**2)
	MSE denotes mean square entropy

	:param target: 目標圖
	:param ref:  待比較的圖片
	:param scale:
	:return: scalar
	"""
	target_data = np.array(target, dtype=np.float32)
	ref_data = np.array(ref, dtype=np.float32)

	# 將RGB圖轉化為YCbCr顏色格式再求PSNR
	target_y = rgb2ycbcr(target_data)
	ref_y = rgb2ycbcr(ref_data)
	diff = ref_y - target_y

	shave = scale
	diff = diff[shave:-shave, shave:-shave]

	mse = np.mean((diff / 255) ** 2)
	if mse == 0:
		return 100

	return -10 * math.log10(mse)