在 Keras 下實現 DnCNN:Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
阿新 • 發佈:2018-11-21
使用 Keras 實現論文《Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising》(DnCNN)。
generator_data.py
import glob
import os
from multiprocessing import Pool

import cv2
import numpy as np

patch_size, stride = 40, 10
aug_times = 1


def data_aug(img, mode=0):
    """Return one of the 8 rotation/flip variants of *img*.

    Mode ``m`` rotates by ``90 * (m // 2)`` degrees with ``np.rot90`` and, for
    odd ``m``, then flips up-down with ``np.flipud`` (mode 0 is the identity).

    Args:
        img: 2-D image array.
        mode: integer in ``[0, 7]``.

    Returns:
        The transformed array (a view of *img* where NumPy allows it).

    Raises:
        ValueError: if *mode* is not in ``[0, 7]`` (previously ``None`` was
            returned silently, which only failed much later).
    """
    if mode not in range(8):
        raise ValueError('mode must be an integer in [0, 7], got {!r}'.format(mode))
    rotated = np.rot90(img, k=mode // 2)
    return np.flipud(rotated) if mode % 2 else rotated


def gen_patches(file_name, seed=None):
    """Extract augmented grayscale training patches from one image file.

    The image is rescaled by each factor in ``scales``. From every rescaled
    copy, ``patch_size`` x ``patch_size`` patches are cut on a ``stride`` grid,
    and each patch is added ``aug_times`` times with a random
    :func:`data_aug` mode.

    Args:
        file_name: path of the image to read (loaded as grayscale).
        seed: optional seed for the augmentation RNG. ``None`` (default) seeds
            from OS entropy on every call. This matters because forked Pool
            workers would otherwise share the parent's global NumPy random
            state and all draw the same "random" modes.

    Returns:
        list of 2-D arrays of shape ``(patch_size, patch_size)``.

    Raises:
        IOError: if OpenCV cannot read *file_name*.
    """
    img = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imread signals failure by returning None, not raising
        raise IOError('cannot read image: {}'.format(file_name))
    rng = np.random.RandomState(seed)
    h, w = img.shape
    scales = [1, 0.9, 0.8, 0.7]
    patches = []
    for s in scales:
        h_scaled, w_scaled = int(h * s), int(w * s)
        # cv2.resize takes dsize as (width, height). Passing (height, width)
        # transposes non-square images and breaks the row/column loops below.
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
        # extract patches on a regular grid
        for i in range(0, h_scaled - patch_size + 1, stride):
            for j in range(0, w_scaled - patch_size + 1, stride):
                x = img_scaled[i:i + patch_size, j:j + patch_size]
                for _ in range(aug_times):
                    patches.append(data_aug(x, mode=rng.randint(0, 8)))
    return patches


if __name__ == '__main__':
    # parameters
    src_dir = './data/Train500/'
    # Must match main.py's default --train_data ('./data/npy_data/clean_patches.npy');
    # the directory name previously had a typo ('npy_dataa').
    save_dir = './data/npy_data/'
    # sorted() gives a reproducible patch order across platforms
    file_list = sorted(glob.glob(os.path.join(src_dir, '*.png')))
    if not file_list:
        raise SystemExit('No .png files found in ' + src_dir)
    num_threads = 16
    print('Start...')
    res = []
    # One pool for the whole run. The context manager terminates the workers;
    # previously a new, never-closed Pool was created for every chunk.
    with Pool(num_threads) as pool:
        for start in range(0, len(file_list), num_threads):
            chunk = file_list[start:start + num_threads]
            for patches in pool.map(gen_patches, chunk):
                res += patches
            print('Picture ' + str(start) + ' to ' + str(start + len(chunk)) + ' are finished...')
    # save to .npy
    res = np.array(res, dtype='uint8')
    print('Shape of result = ' + str(res.shape))
    print('Saving data...')
    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'clean_patches.npy'), res)
    print('Done.')
models.py
from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, Activation, Subtract


def DnCNN(depth=17, filters=64, image_channels=1):
    """Build the residual-learning DnCNN denoiser.

    Layer layout:

    * one Conv + ReLU layer;
    * ``depth - 2`` Conv + BatchNorm + ReLU layers;
    * a final Conv that predicts the noise residual.

    The residual is subtracted from the input, so the model outputs the
    denoised image. The input has unspecified spatial size
    ``(None, None, image_channels)``.

    Args:
        depth: total number of convolution layers. The default of 17 reproduces
            the previous hard-coded 1 + 15 + 1 architecture.
        filters: number of feature maps in every hidden convolution.
        image_channels: channels of the input and output image (1 = grayscale).

    Returns:
        An uncompiled keras ``Model``. Its summary is printed as a side effect.

    Raises:
        ValueError: if ``depth`` < 2 (the first and last layers are mandatory).
    """
    if depth < 2:
        raise ValueError('depth must be >= 2 (first and last conv layers), got {}'.format(depth))
    inpt = Input(shape=(None, None, image_channels))
    # 1st layer, Conv+relu
    x = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(inpt)
    x = Activation('relu')(x)
    # depth-2 middle layers, Conv+BN+relu
    for _ in range(depth - 2):
        x = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
        x = BatchNormalization(axis=-1, epsilon=1e-8)(x)
        x = Activation('relu')(x)
    # last layer, Conv: predicts the noise. It must have image_channels filters
    # so the Subtract below sees tensors of identical shape.
    x = Conv2D(filters=image_channels, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = Subtract()([inpt, x])  # input - noise
    model = Model(inputs=inpt, outputs=x)
    model.summary()
    return model
main.py
import argparse
import datetime
import glob
import logging
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image as Image
import tensorflow as tf
from keras import backend as K
from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler
from keras.callbacks import TensorBoard
from keras.models import load_model
from keras.optimizers import Adam

try:
    # scikit-image >= 0.16 location (skimage.measure.compare_psnr was removed in 0.18)
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
except ImportError:
    # older scikit-image releases, the API this script was written against
    from skimage.measure import compare_psnr

import models


def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` does not work with argparse: ``bool('False')`` is True, so
    ``--only_test False`` used to turn test-only mode ON. This accepts the usual
    spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _setup_logging(log_dir):
    """Send INFO+ log records to ``<log_dir>/info.log`` (overwritten) and to the console."""
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        filename=os.path.join(log_dir, 'info.log'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(name)-6s: %(levelname)-6s %(message)s'))
    logging.getLogger('').addHandler(console)


parser = argparse.ArgumentParser()
# choices: any other value previously crashed train() with a NameError
parser.add_argument('--model', default='DnCNN', type=str, choices=['DnCNN'], help='choose a type of model')
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
parser.add_argument('--train_data', default='./data/npy_data/clean_patches.npy', type=str, help='path of train data')
parser.add_argument('--test_dir', default='./data/Test/BSD68', type=str, help='directory of test dataset')
parser.add_argument('--sigma', default=25, type=int, help='noise level')
parser.add_argument('--epoch', default=75, type=int, help='number of train epoches')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
parser.add_argument('--save_every', default=5, type=int, help='save model at every x epoches')
parser.add_argument('--pretrain', default=None, type=str, help='path of pre-trained model')
# nargs='?' + const=True: both "--only_test" and "--only_test True/False" work
parser.add_argument('--only_test', default=False, type=_str2bool, nargs='?', const=True,
                    help='train and test or only test')
parser.add_argument('--img_test', default='./data/img_test/', type=str, help='directory to save the noisy test images')
parser.add_argument('--img_out', default='./data/img_out/', type=str, help='directory to save the denoised test images')
args = parser.parse_args()

if not args.only_test:
    save_dir = './snapshot/save_' + args.model + '_' + 'sigma' + str(args.sigma) + '_' + \
        time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '/'
    os.makedirs(save_dir, exist_ok=True)
    _setup_logging(save_dir)
    logging.info(args)
else:
    if not args.pretrain:
        parser.error('--only_test requires --pretrain (path of the model to evaluate)')
    # Test results are written next to the pre-trained model. A bare filename
    # maps to './' rather than the filesystem root.
    save_dir = (os.path.dirname(args.pretrain) or '.') + '/'


def load_train_data():
    """Load the clean training patches saved by generator_data.py."""
    logging.info('loading train data...')
    data = np.load(args.train_data)
    logging.info('Size of train data: ({}, {}, {})'.format(data.shape[0], data.shape[1], data.shape[2]))
    return data


def log(*messages, **kwargs):
    """print() with a timestamp prefix."""
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *messages, **kwargs)


def train_datagen(y_, batch_size=64):
    """Yield ``(noisy, clean)`` batches forever, reshuffling at every pass.

    Fresh Gaussian noise with std ``args.sigma / 255`` is drawn for every batch.
    *y_* is indexed along its first axis; the last batch of a pass may be smaller.
    """
    indices = list(range(y_.shape[0]))
    while True:
        np.random.shuffle(indices)  # shuffle order every pass
        for i in range(0, len(indices), batch_size):
            ge_batch_y = y_[indices[i:i + batch_size]]
            noise = np.random.normal(0, args.sigma / 255.0, ge_batch_y.shape)
            # the float64 noise would otherwise upcast the whole input batch
            ge_batch_x = (ge_batch_y + noise).astype(np.float32)
            yield ge_batch_x, ge_batch_y


def train():
    """Train (or fine-tune ``--pretrain``) the model and return it.

    Writes checkpoints, the loss CSV and TensorBoard logs under ``save_dir``.
    """
    print("-----start train-----")
    data = load_train_data()  # load the .npy patches
    data = data.reshape((data.shape[0], data.shape[1], data.shape[2], 1))
    data = data.astype('float32') / 255.0  # scale to [0, 1]
    if args.pretrain:
        model = load_model(args.pretrain, compile=False)
    elif args.model == 'DnCNN':
        model = models.DnCNN()
    else:
        raise ValueError('unknown model: {}'.format(args.model))
    start = time.time()
    # Previously a hard-coded absolute Windows path; keep logs with this run.
    tensorboard = TensorBoard(log_dir=os.path.join(save_dir, 'tensorboard'), write_graph=True, write_images=True)
    # --lr was parsed but never used: Adam() silently ran with its default rate.
    model.compile(optimizer=Adam(lr=args.lr), loss=['mse'])
    # No validation data is passed to fit_generator, so monitor the training loss.
    ckpt = ModelCheckpoint(os.path.join(save_dir, 'model_{epoch:02d}.h5'), monitor='loss', verbose=0,
                           period=args.save_every)
    csv_logger = CSVLogger(os.path.join(save_dir, 'loss_MSE.csv'), append=True, separator=',')
    history = model.fit_generator(train_datagen(data, batch_size=args.batch_size), steps_per_epoch=2000,
                                  epochs=args.epoch, verbose=1, callbacks=[ckpt, csv_logger, tensorboard])
    end = time.time()
    print("train time:", end - start)
    plt.plot(history.history['loss'])
    plt.title("model loss")
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.legend(["train"], loc="upper left")
    plt.show()
    return model


def _to_uint8(img):
    """Convert a float image in [0, 1] to uint8.

    Values are clipped first: a direct astype('uint8') on out-of-range values
    (e.g. noisy pixels below 0) wraps around and produces salt noise.
    """
    return (np.clip(img, 0.0, 1.0) * 255.0).round().astype(np.uint8)


def test(model):
    """Add noise to every .png in ``--test_dir``, denoise it and report PSNR.

    Writes the noisy and denoised images to ``--img_test`` / ``--img_out`` and a
    per-image PSNR table (plus the average) to ``<save_dir>/<test set>/PSNR.csv``.

    Raises:
        FileNotFoundError: if ``--test_dir`` contains no .png files.
    """
    print("-----start test-----")
    # basename(normpath()) also handles a trailing slash and Windows separators
    out_dir = os.path.join(save_dir, os.path.basename(os.path.normpath(args.test_dir)))
    for directory in (out_dir, args.img_test, args.img_out):
        os.makedirs(directory, exist_ok=True)  # Image.save() does not create directories
    name = []
    psnr = []
    x = []
    start = time.time()
    # sorted: output images are named by index, so keep the index->file mapping stable
    file_list = sorted(glob.glob(os.path.join(args.test_dir, '*.png')))
    if not file_list:
        raise FileNotFoundError('no .png files found in {}'.format(args.test_dir))
    for index, file in enumerate(file_list):
        x.append(index)
        # read the test image as grayscale, scaled to [0, 1]; the with-block closes the file
        with Image.open(file) as pil_img:
            img_clean = np.array(pil_img.convert('L'), dtype='float32') / 255.0
        # add Gaussian noise
        img_test = img_clean + np.random.normal(0, args.sigma / 255.0, img_clean.shape)
        img_test = img_test.astype('float32')
        # denoise a single (1, H, W, 1) image
        x_test = img_test.reshape(1, img_test.shape[0], img_test.shape[1], 1)
        y_predict = model.predict(x_test)
        img_out = np.clip(y_predict.reshape(img_clean.shape), 0, 1)
        # PSNR for images in [0, 1]
        psnr_noise = compare_psnr(img_clean, img_test, data_range=1.0)
        psnr_denoised = compare_psnr(img_clean, img_out, data_range=1.0)
        psnr.append(psnr_denoised)
        name.append(os.path.splitext(os.path.basename(file))[0])
        Image.fromarray(_to_uint8(img_test)).save(
            os.path.join(args.img_test, 'sigma' + '{}_psnr{:.2f}.png'.format(index, psnr_noise)))
        Image.fromarray(_to_uint8(img_out)).save(
            os.path.join(args.img_out, '{}_psnr{:.2f}.png'.format(index, psnr_denoised)))
    plt.plot(x, psnr)
    psnr_avg = sum(psnr) / len(psnr)
    name.append('Average')
    psnr.append(psnr_avg)
    print("-----PSNR-----")
    print("Average PSNR = {0:.4f}".format(psnr_avg))
    pd.DataFrame({'name': np.array(name), 'psnr': np.array(psnr)}).to_csv(
        os.path.join(out_dir, 'PSNR.csv'), index=True)
    end = time.time()
    print("test time:", end - start)
    plt.show()


if __name__ == '__main__':
    if args.only_test:
        model = load_model(args.pretrain, compile=False)
        test(model)
    else:
        with tf.device('/gpu:0'):
            model = train()
            print("-----train end-----")
            test(model)