1. 程式人生 > >使用自動編解碼器網路實現圖片噪音去除

使用自動編解碼器網路實現圖片噪音去除

在前面章節中,我們一再看到,訓練或使用神經網路進行預測時,我們需要把資料轉換成張量。例如要把圖片輸入卷積網路,我們需要把圖片轉換成二維張量,如果要把句子輸入LSTM網路,我們需要把句子中的單詞轉換成one-hot-encoding向量。
這種資料型別轉換往往是由人設計的,我們本節介紹一種神經網路,它能夠為輸入資料自動找到合適的資料轉換方法,它自動把資料轉換成某種格式的張量,然後又能把相應張量還原回原有形態,這種網路就叫自動編解碼器。

自動編解碼器的功能很像加解密系統,對加密而言,當把明文進行加密後,形成的密文是一種隨機字串,再把密文解密後就可以得到明文,解密後的資料必須與加密前的完全一模一樣。自動編解碼器會把輸入的資料,例如是圖片轉換成給定維度的張量,例如一個含有16個元素的一維向量,解碼後它會把對應的含有16個元素的一維向量轉換為原有圖片,不過轉換後的圖片與原圖片不一定完全一樣,但是圖片內容絕不會有重大改變。
自動編解碼器分為兩部分,一部分叫encoder,它負責把資料轉換成固定格式,從數學上看,encoder相當於一個函式,被編碼的資料相當於輸入引數,編碼後的張量相當於函式輸出: ,其中f對應encoder,x對應要編碼的資料,例如圖片,z是編碼後的結果。
另一部分叫decoder,也就是把編碼器編碼的結果還原為原有資料,用數學來表達就是: ,函式g相當於解碼器,它的輸入是編碼器輸出結果, 是解碼器還原結果,它與輸入編碼器的資料可能有差異,但主要內容會保持不變,如圖10-1:

10-1.png

圖10-1 編解碼器執行示意圖

如上圖,手寫數字圖片7經過編碼器後,轉換成給定維度的張量,例如含有16個元素的一維張量,然後經過解碼器處理後還原成一張手寫數字圖片7,還原的圖片與輸入的圖片影象顯示上有些差異,但是他們都能表達手寫數字7這一含義。
程式碼是對原理最好的解釋,我們看看實現過程:

from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras.datasets import mnist
from keras.utils import plot_model
from keras import backend as K

import numpy as np
import matplotlib.pyplot as plt

#載入手寫數字圖片資料
(x_train, _), (x_test, _) = mnist.load_data()
image_size = x_train.shape[1]
#把圖片大小統一轉換成28*28,並把畫素點值都轉換為[0,1]之間
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
接下來我們構建自動編解碼器網路:
#構建解碼器
latent_inputs = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

'''
使用Conv2DTranspose做卷積操作的逆操作。相應的Conv2D做怎樣的計算操作,該網路層就逆著來
'''
for filters in layer_filters[::-1]:
  x = Conv2DTranspose(filters = filters, kernel_size = kernel_size, 
                     activation='relu', strides = 2, padding='same')(x)

#還原輸入
outputs = Conv2DTranspose(filters = 1, kernel_size = kernel_size, 
                         activation='sigmoid', padding='same', 
                          name='decoder_output')(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()

我們把編碼器和解碼器前後相連,於是資料從編碼器輸入,編碼器將資料進行計算編號後所得的輸出直接傳給解碼器,解碼器進行相對於編碼器的逆運算最後得到類似於輸入編碼器的資料,相應程式碼如下:


'''
將編碼器和解碼器前後相連,資料從編碼器輸入,編碼器運算後把結果直接傳遞給解碼器,
解碼器進行編碼器的逆運算,最後輸出與資料輸入時相似的結果
'''
autoencoder = Model(inputs, decoder(encoder(inputs)),
                   name='autoencoder')

autoencoder.compile(loss='mse', optimizer='adam')
'''

在訓練網路時,輸入資料是x_train,對應標籤也是x_train,這意味著我們希望網路將輸出儘可能的調整成與輸入一致

'''
autoencoder.fit(x_train, x_train, validation_data=(x_test, x_test), epochs = 1, 
                batch_size = batch_size)

網路訓練好後,我們把圖片輸入網路,編碼器把圖片轉換為含有16個元素的一維向量,然後向量輸入解碼器,解碼器把向量還原為一張二維圖片,相應程式碼如下:

'''
把手寫數字圖片輸入編碼器然後再通過解碼器,檢驗輸出後的影象與輸出時的影象是否相似
'''
x_decoded = autoencoder.predict(x_test)
imgs = np.concatenate([x_test[:8], x_decoded[:8]])
imgs = imgs.reshape((4, 4, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
plt.figure()
plt.axis('off')
plt.title('Input image: first and second rows, Decoded: third and forth rows')
plt.imshow(imgs, interpolation='none', cmap = 'gray')
plt.savefig('input_and_decoded.png')
plt.show()

上面程式碼執行後結果如圖10-2:

10-2.png

圖10-2 程式碼執行結果
上面顯示圖片中,前兩行是輸入編解碼器的手寫數字圖片,後兩行是經過編碼然後還原後的圖片,如果仔細看我們可以發現兩者非常相像,但並不完全一樣,我們看第一行最後一個數字0和解碼後第三行最後一個數字0,兩者有比較明顯差異,但都會被解讀成數字0.
在程式碼中需要注意的是,構建解碼器時我們使用了一個類叫Conv2DTranspose,它與Conv2D對應,是後者的反操作,如果把Conv2D看做對輸入資料的壓縮或加密,那麼Conv2DTranspose是對資料的解壓或解密。
另外還需要注意的是,因為我們網路層較少,因此訓練時只需要一次迴圈就好,如果網路層多的話,我們需要增加迴圈次數才能使得網路有良好的輸出效果。

2.使用編解碼器去除圖片噪音

在八零年代,改革開放不久後,一種‘稀有’的家電悄悄潛入很多家庭,那就是錄影機。你把一盤錄影帶推入機器,在電視上就可以把內容播放出來,有一些錄影帶它的磁帶遭到破壞的話,播放時畫面會飄散一系列‘雪花’,我們將那稱之為畫面‘噪音’。
當圖片含有‘噪音’時,圖片表現為含有很多花點,如圖10-3所示:

10-3.jpg

圖10-3 含有噪音的圖片

在訊號處理這一學科分支中,有很大一部分就在於研究如何去噪,幸運的是通過編解碼網路也能夠實現圖片噪音去除的效果。本節我們先給手寫數字圖片增加噪音,使得圖片變得很難識別,然後我們再使用編解碼網路去除圖片噪音,讓圖片回覆原狀。
圖片噪音本質上是在畫素點上新增一些隨機值,這裡我們使用高斯分佈產生隨機值,其數學公式如下:

螢幕快照 2018-11-20 上午11.42.04.png

它有兩個決定性引數,分別是μ 和 σ,只要使得這兩個引數取不同的值,我們就可以得到相應分佈的隨機數,其中μ 稱之為均值, σ稱之為方差,我們看看如何使用程式碼實現圖片加噪,然後構建編解碼網路去噪音:

#使用高斯分佈產生圖片噪音
np.random.seed(1337)
#使用高斯分佈函式生成隨機數,均值0.5,方差0.5
noise = np.random.normal(loc=0.5, scale=0.5, size=x_train.shape)
x_train_noisy = x_train + noise

noise = np.random.normal(loc=0.5, scale=0.5, size=x_test.shape)
x_test_noisy = x_test + noise
#把畫素點取值範圍轉換到[0,1]間
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)
上面的程式碼先使用高斯函式產生隨機數,然後加到畫素點上從而形成圖片噪音。接著我們看如何構建編解碼器實現圖片去噪:
#構造編解碼網路,以下程式碼與上一小節程式碼大部分相同
input_shape = (image_size, image_size, 1)
batch_size = 32
kernel_size = 3
latent_dim = 16
layer_filters = [32, 64]
#構造編碼器
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for filters in layer_filters:
  x = Conv2D(filters = filters, kernel_size = kernel_size, strides = 2,
            activation='relu', padding='same')(x)

shape = K.int_shape(x)
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)
encoder = Model(inputs, latent, name='encoder')

#構造解碼器
latent_inputs = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

for filters in layer_filters[::-1]:
  x = Conv2DTranspose(filters = filters, kernel_size = kernel_size, strides = 2,
                     activation='relu', padding='same')(x)
  
outputs = Conv2DTranspose(filters=1, kernel_size=kernel_size, padding='same',
                         activation='sigmoid', name='decoder_output')(x)
decoder = Model(latent_inputs, outputs, name='decoder')


#將編碼器和解碼器前後相連
autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder')
autoencoder.compile(loss='mse', optimizer='adam')
#輸入資料是有噪音圖片,對應結果是無噪音圖片
autoencoder.fit(x_train_noisy, x_train, validation_data=(x_test_noisy, x_test),
               epochs = 10, batch_size = batch_size)

程式碼中值得注意的是,訓練網路時,訓練資料時含有噪音的圖片,對應結果是沒有噪音的圖片,也就是我們希望網路能通過學習自動掌握去噪功能,訓練完成後,我們把測試圖片輸入網路,看看噪音去除效果:

x_decoded = autoencoder.predict(x_test_noisy)

rows, cols = 3, 9
num = rows * cols
imgs = np.concatenate([x_test[:num], x_test_noisy[:num],
                      x_decoded[:num]])
imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
imgs = np.vstack(np.split(imgs, rows, axis=1))
imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
imgs = (imgs * 255).astype(np.uint8)
plt.figure()
plt.axis('off')
plt.title('Original images: top rows'
          'Corrupted images: middle rows'
         'Denoised Input: third rows')
plt.imshow(imgs, interpolation='none', cmap='gray')
plt.show()

上面程式碼執行後如圖10-4所示:

10-4.png

圖10-4 網路去噪效果
從上圖看,第一行是原圖,第二行是加了噪音的圖片,第三行是網路去除噪音後的圖片。從上圖看,網路去噪的效果還是比較完美的。

更多內容,請點選進入csdn學院

更多技術資訊,包括作業系統,編譯器,面試演算法,機器學習,人工智慧,請關照我的公眾號:
這裡寫圖片描述