1. 程式人生 > >深度學習結合非區域性均值濾波的影象去噪演算法

深度學習結合非區域性均值濾波的影象去噪演算法

其實這是半年之前完成的內容,一直懶著沒有總結,今天看了看程式碼,發覺再不總結自己以後都看不懂了,故整理如下。

非區域性均值是一種基於塊匹配來確定濾波權值的。即先確定一個塊的大小,例如7x7,然後在確定一個搜尋區域,例如15x15,在15x15這個搜尋區域中的每一個點,計算7x7的視窗與當前濾波點7x7視窗的相似性(使用絕對差和SAD,一般而言,視窗中各點的差值還需要乘以經高斯核生成的權重引數,離中心點越近,權重值越大一些),然後根據相似性值使用指數函式生成視窗中心點的權重引數,相似性越高,該中心點的權重越大,最後各中心點的加權平均就是最終濾波影象,能獲得很好的視覺效果。

非區域性均值的成功之處主要在於充分利用了塊的相似性,而後續步驟由相似性計算對應權重值,按照經驗使用指數函式,其引數h有著至關重要的作用,許多論文也是在h上面做改進。如果我們跳出加權平均和指數函式的思路,完全可以將含噪影象所有相鄰點的畫素值、相似性值、距離等做為輸入送給深度學習網路,將原影象值作為輸出進行訓練啊,訓練好的模型就可以直接用於濾波。

下面附一個簡化版的python程式碼,經實測改進後的演算法比原生的非區域性均值濾波要好,裡面的網路模型過於簡單,想提升效果的自己修改調優吧。

注意使用的是python3環境

#coding:utf8
import cv2, datetime,sys,glob
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from keras.models import Sequential, model_from_json
from keras.layers import Dense, Activation,Dropout,Flatten,Merge
from keras.callbacks import EarlyStopping
from keras.layers.convolutional import Convolution2D,Convolution3D

def psnr(A, B):
    return 10*np.log(255*255.0/(((A.astype(np.float)-B)**2).mean()))/np.log(10)
def double2uint8(I, ratio=1.0):
    return np.clip(np.round(I*ratio), 0, 255).astype(np.uint8)

def GetNlmData(I, templateWindowSize=4,  searchWindowSize=9):
    f = int(templateWindowSize / 2)
    t = int(searchWindowSize / 2)
    height, width = I.shape[:2]
    padLength = t + f
    I2 = np.pad(I, padLength, 'symmetric')
    I_ = I2[padLength - f:padLength + f + height, padLength - f:padLength + f + width]

    res = np.zeros((height, width, templateWindowSize+2, t+t+1, t+t+1))
    for i in range(-t, t + 1):
        for j in range(-t, t + 1):
            I2_ = I2[padLength + i - f:padLength + i + f + height, padLength + j - f:padLength + j + f + width]
            for kk in range(templateWindowSize):
                kernel = np.ones((2*kk+1, 2*kk+1))
                kernel = kernel/kernel.sum()
                res[:, :, kk, i+t, j+t] = cv2.filter2D((I2_-I_) ** 2, -1,  kernel)[f:f + height, f:f + width]
            res[:, :, -2, i+t, j+t] = I2_[f:f + height, f:f + width]-I
            res[:, :, -1, i+t, j+t] = np.exp(-np.sqrt(i**2+j**2))
    print(res.max(), res.min())
    return res

def zmTrain(trainX, trainY):
    model = Sequential()
    if 1:
        model.add(Dense(100, init='uniform', input_dim=trainX.shape[1]))
        model.add(Activation('relu'))
        model.add(Dense(50))
        model.add(Activation('relu'))
        model.add(Dense(1))
        model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
    else:
        with open('model.json', 'rb') as fd:
            model = model_from_json(fd.read())
            model.load_weights('weight.h5')
            model.compile(loss='msle', optimizer='adam', metrics=['accuracy'])

    early_stopping = EarlyStopping(monitor='val_loss', patience=5)
    hist =model.fit(trainX, trainY, batch_size=150, epochs=200, shuffle=True, verbose=2, validation_split=0.1
                    ,callbacks=[early_stopping])
    print(hist.history)

    res = model.predict(trainX)
    res = np.clip(np.round(res.ravel() * 255), 0, 255)
    print(psnr(res, trainY*255))
    return model
if __name__ == '__main__':
    sigma = 20.0
    if 1:                         #這部分程式碼用於訓練模型
        trainX = None
        trainY = None

        for d in glob.glob('./img/_*'):
            I = cv2.imread(d,0)
            I1 = double2uint8(I + np.random.randn(*I.shape) *sigma)
            data = GetNlmData(I1.astype(np.double)/255)
            s = data.shape
            data.resize((np.prod(s[:2]), np.prod(s[2:])))
            
            if trainX is None:
                trainX = data
                trainY = ((I.astype(np.double)-I1)/255).ravel()
            else:
                trainX = np.concatenate((trainX, data), axis=0)
                trainY = np.concatenate((trainY,  ((I.astype(np.double)-I1)/255).ravel()), axis=0)
            
        
        model = zmTrain(trainX, trainY)
        with open('model.json', 'wb') as fd:
            #fd.write(model.to_json())
            fd.write(bytes(model.to_json(),'utf8'))
        model.save_weights('weight.h5')
    if 1:                       #濾波
        with open('model.json', 'rb') as fd:
            model = model_from_json(fd.read().decode())
            model.load_weights('weight.h5')
        I = cv2.imread('lena.jpg', 0)
        I1 = double2uint8(I + np.random.randn(*I.shape) * sigma)

        data= GetNlmData(I1.astype(np.double)/255)
        s = data.shape
        data.resize((np.prod(s[:2]), np.prod(s[2:])))
        res = model.predict(data)
        res.resize(I.shape)
        res = np.clip(np.round(res*255 +I1), 0, 255)
        print('nwNLM PSNR', psnr(res, I))
        res = res.astype(np.uint8)
        cv2.imwrite('cvOut.bmp', res)