
Implementing "Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising" in Keras

This post implements the paper "Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising" (DnCNN) in Keras. DnCNN learns to predict the noise residual, so the denoised image is the noisy input minus the network output. The code is split into three files: generator_data.py builds a set of clean 40x40 training patches, models.py defines the network, and main.py trains on synthetically noised patches and evaluates PSNR on a test set.

generator_data.py

import glob
import os
import cv2
import numpy as np
from multiprocessing import Pool

patch_size, stride = 40, 10
aug_times = 1

def data_aug(img, mode=0):
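    # the eight modes cover the four 90-degree rotations of the patch and a vertically flipped version of each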
    
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(img)
    elif mode == 2:
        return np.rot90(img)
    elif mode == 3:
        return np.flipud(np.rot90(img))
    elif mode == 4:
        return np.rot90(img, k=2)
    elif mode == 5:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 6:
        return np.rot90(img, k=3)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))
    
def gen_patches(file_name):
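    # crop overlapping patch_size x patch_size patches with the given stride at four scales, applying a random augmentation to each patch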

    # read image
    img = cv2.imread(file_name, 0)  # gray scale
    h, w = img.shape
    scales = [1, 0.9, 0.8, 0.7]
    patches = []

    for s in scales:
        h_scaled, w_scaled = int(h*s), int(w*s)
        # cv2.resize takes dsize as (width, height)
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
        # extract patches
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size]
                # data aug
                for k in range(0, aug_times):
                    x_aug = data_aug(x, mode=np.random.randint(0,8))
                    patches.append(x_aug)
    
    return patches

if __name__ == '__main__':
    # parameters
    src_dir = './data/Train500/'
    save_dir = './data/npy_data/'
    file_list = glob.glob(src_dir+'*.png')  # get name list of all .png files
    num_threads = 16    
    print('Start...')
    # initialize
    res = []
    # generate patches
    for i in range(0,len(file_list),num_threads):
        # use multi-process to speed up
        p = Pool(num_threads)
        patch = p.map(gen_patches, file_list[i:min(i+num_threads, len(file_list))])
        p.close()
        p.join()
        for x in patch:
            res += x

        print('Picture ' + str(i) + ' to ' + str(min(i+num_threads, len(file_list))) + ' are finished...')
    
    # save to .npy
    res = np.array(res, dtype='uint8')
    print('Shape of result = ' + str(res.shape))
    print('Saving data...')
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    np.save(save_dir+'clean_patches.npy', res)
    print('Done.')       
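
After running the script, a quick sanity check (a minimal sketch, assuming the default save path above) is to load the patch file and inspect it:

import numpy as np

patches = np.load('./data/npy_data/clean_patches.npy')
print(patches.shape, patches.dtype)  # expected: (num_patches, 40, 40) uint8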

models.py


from keras.models import Model
from keras.layers import  Input,Conv2D,BatchNormalization,Activation,Subtract

def DnCNN():
    
    inpt = Input(shape=(None,None,1))
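    # depth-17 network (1 Conv+ReLU, 15 Conv+BN+ReLU, 1 Conv) for a fixed noise level, as in the paper's DnCNN-S setting; it predicts the noise, which is subtracted from the input at the end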
    # 1st layer, Conv+relu
    x = Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding='same')(inpt)
    x = Activation('relu')(x)
    # 15 layers, Conv+BN+relu
    for i in range(15):
        x = Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding='same')(x)
        x = BatchNormalization(axis=-1, epsilon=1e-8)(x)
        x = Activation('relu')(x)   
    # last layer, Conv
    x = Conv2D(filters=1, kernel_size=(3,3), strides=(1,1), padding='same')(x)
    x = Subtract()([inpt, x])   # input - noise
    model = Model(inputs=inpt, outputs=x)
    model.summary()
    return model
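
A minimal sketch of how the model can be exercised (the random input is only for illustration; any spatial size works because the input shape is (None, None, 1)):

import numpy as np
import models

model = models.DnCNN()
dummy = np.random.rand(1, 40, 40, 1).astype('float32')
denoised = model.predict(dummy)
print(denoised.shape)  # (1, 40, 40, 1): same spatial size as the input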

main.py

import argparse
import logging
import os,glob
import PIL.Image as Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import models
import tensorflow as tf
import time
import datetime
import math
from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler
from keras.models import load_model
from keras.optimizers import Adam
from skimage.measure import compare_psnr
from keras.callbacks import TensorBoard
from keras import backend as K



parser = argparse.ArgumentParser()
parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
parser.add_argument('--train_data', default='./data/npy_data/clean_patches.npy', type=str, help='path of train data')
parser.add_argument('--test_dir', default='./data/Test/BSD68', type=str, help='directory of test dataset')
parser.add_argument('--sigma', default=25, type=int, help='noise level')
parser.add_argument('--epoch', default=75, type=int, help='number of train epoches')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
parser.add_argument('--save_every', default= 5,type=int, help='save model at every x epoches')
parser.add_argument('--pretrain', default=None, type=str, help='path of pre-trained model')
parser.add_argument('--only_test', default=False, type=bool, help='train and test or only test')
parser.add_argument('--img_test', default='./data/img_test/', type=str, help='directory of test dataset')
parser.add_argument('--img_out', default='./data/img_out/', type=str, help='directory of test dataset')
args = parser.parse_args()


if not args.only_test:
    save_dir = './snapshot/save_'+ args.model + '_' + 'sigma' + str(args.sigma) + '_' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    logging.basicConfig(level=logging.INFO,format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                    datefmt='%Y %H:%M:%S',
                    filename=save_dir+'info.log',
                    filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)-6s: %(levelname)-6s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    logging.info(args)
else:
    save_dir = '/'.join(args.pretrain.split('/')[:-1]) + '/'


def load_train_data():
    logging.info('loading train data...')
   
    data = np.load(args.train_data)
    logging.info('Size of train data: ({}, {}, {})'.format(data.shape[0],data.shape[1],data.shape[2]))
    return data


def log(*args,**kwargs):
    
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),*args,**kwargs)




def train_datagen(y_, batch_size = 64):
    
    indices = list(range(y_.shape[0]))
    while(True):
        np.random.shuffle(indices)  # shuffle the patch order each epoch
        for i in range(0, len(indices), batch_size):
            ge_batch_y = y_[indices[i:i+batch_size]]
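            # additive white Gaussian noise with standard deviation sigma/255 (patches are scaled to [0, 1] in train())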
           
            noise =  np.random.normal(0, args.sigma/255.0, ge_batch_y.shape)  
         
            ge_batch_x = ge_batch_y + noise
            yield ge_batch_x, ge_batch_y



      
# training
def train():
    print ("-----start train-----")
    data = load_train_data()  # load the .npy patch file
    data = data.reshape((data.shape[0], data.shape[1], data.shape[2], 1))
    data = data.astype('float32')/255.0  # scale pixel values to [0, 1]
   
    if args.pretrain:   
        model = load_model(args.pretrain, compile=False)
    else:   
        if args.model == 'DnCNN': 
            model = models.DnCNN()
    start = time.time()
    tensorboard = TensorBoard(log_dir='./tensorboard/', write_graph=True, write_images=True)
   
    model.compile(optimizer=Adam(lr=args.lr), loss='mse')
   
    ckpt = ModelCheckpoint(save_dir+'/model_{epoch:02d}.h5', monitor='val_loss', 
                    verbose=0, period=args.save_every)
    
    csv_logger = CSVLogger(save_dir+'/loss_MSE.csv', append=True, separator=',')
    
 
   
    history = model.fit_generator(train_datagen(data, batch_size=args.batch_size),
                    steps_per_epoch=2000, epochs=args.epoch, verbose=1, 
                    callbacks=[ckpt, csv_logger,tensorboard])
    end = time.time()
    print ("train time:",end - start)    
    plt.plot(history.history['loss'])
    plt.title("model loss")
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.legend(["train"],loc="upper left")
    plt.show()
    return model

def test(model):
    print ("-----start test-----")
    out_dir = save_dir + args.test_dir.split('/')[-1] + '/'
    for d in (out_dir, args.img_test, args.img_out):
        if not os.path.exists(d):
            os.makedirs(d)
    # containers for per-image results
    name = []
    psnr = []
    x=[]
    # read every image in the test set
    start = time.time()
    file_list = glob.glob('{}/*.png'.format(args.test_dir))
    for index,file in enumerate(file_list):
        x.append(index)
        # load the clean test image and scale to [0, 1]
        img_clean = np.array(Image.open(file), dtype='float32') / 255.0
        # add Gaussian noise at the chosen noise level
        img_test = img_clean + np.random.normal(0, args.sigma/255.0, img_clean.shape)
        # cast back to float32
        img_test = img_test.astype('float32')
        # reshape to a single-sample batch (1, H, W, 1) for prediction
        x_test = img_test.reshape(1, img_test.shape[0], img_test.shape[1], 1)
        # denoise img_test
        y_predict = model.predict(x_test, batch_size=1024)
        # reshape the prediction back to the original image shape
        img_out = y_predict.reshape(img_clean.shape)
        # clip values below 0 to 0 and above 1 to 1
        img_out = np.clip(img_out, 0, 1)
        # compute PSNR of the noisy and the denoised image against the clean image
        psnr_noise, psnr_denoised = compare_psnr(img_clean, img_test), compare_psnr(img_clean, img_out)
       
        psnr.append(psnr_denoised)
      
        # save the noisy and denoised images
        filename = os.path.basename(file).split('.')[0]  # image name without extension
        name.append(filename)
        # convert the [0, 1] arrays back to 8-bit images with PIL
        img_test = Image.fromarray((img_test*255).astype('uint8'))
        img_test.save(args.img_test + 'sigma' + '{}_psnr{:.2f}.png'.format(index, psnr_noise))
        img_out = Image.fromarray((img_out*255).astype('uint8'))
        img_out.save( args.img_out+'{}_psnr{:.2f}.png'.format(index,psnr_denoised))
    plt.plot(x, psnr)  
    psnr_avg = sum(psnr)/len(psnr)
    name.append('Average')
    psnr.append(psnr_avg)
    print ("-----PSNR-----")
    print("Average PSNR = {0:.4f}".format(psnr_avg))
    pd.DataFrame({'name':np.array(name), 'psnr':np.array(psnr)}).to_csv(out_dir+'/PSNR.csv', index=True)
    end = time.time()
    print ("test time:",end - start)
    plt.show()

if __name__ == '__main__': 
    
    if args.only_test:
        model = load_model(args.pretrain, compile=False)
        test(model)
    else:
        with tf.device('/gpu:0'):
            model = train()
            print ("-----train  end-----")
            test(model)
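
Typical invocations (a sketch based on the argparse defaults above; the snapshot path is hypothetical because save_dir is timestamped at training time):

# 1. generate the training patches
python generator_data.py
# 2. train for 75 epochs at sigma=25, then test on BSD68
python main.py --model DnCNN --sigma 25 --batch_size 64
# 3. test only with a saved snapshot; argparse's type=bool treats any non-empty
#    string as True, so pass --only_test True (never False) to skip training
python main.py --only_test True --pretrain ./snapshot/save_DnCNN_sigma25_<timestamp>/model_75.h5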