1. 程式人生 > >殘差網路ResNet網路原理及實現

殘差網路ResNet網路原理及實現

全文共1483字,5張圖,預計閱讀時間10分鐘。

作者介紹:

石曉文,中國人民大學資訊學院在讀研究生,美團外賣演算法實習生

簡書ID:石曉文的學習日記(https://www.jianshu.com/u/c5df9e229a67)

天善社群:https://www.hellobi.com/u/58654/articles

騰訊雲:https://cloud.tencent.com/developer/user/1622140

開發者頭條:https://toutiao.io/u/470599

論文地址:https://arxiv.org/pdf/1512.03385.pdf

引言-深度網路的退化問題

在深度神經網路訓練中,從經驗來看,隨著網路深度的增加,模型理論上可以取得更好的結果。但是實驗卻發現,深度神經網路中存在著退化問題(Degradation problem)。可以看到,在下圖中56層的網路比20層網路效果還要差。

640?wx_fmt=jpeg

上面的現象與過擬合不同,過擬合的表現是訓練誤差小而測試誤差大,而上面的圖片顯示訓練誤差和測試誤差都是56層的網路較大。

深度網路的退化問題至少說明深度網路不容易訓練。我們假設這樣一種情況,56層的網路的前20層和20層網路引數一模一樣,而後36層是一個恆等對映( identity mapping),即輸入x輸出也是x,那麼56層的網路的效果也至少會和20層的網路效果一樣,可是為什麼出現了退化問題呢?因此我們在訓練深層網路時,訓練方法肯定存在的一定的缺陷。

正是上面的這個有趣的假設,何凱明博士發明了殘差網路ResNet來解決退化問題!讓我們來一探究竟!

ResNet網路結構

ResNet中最重要的是殘差學習單元:

640?wx_fmt=jpeg

對於一個堆積層結構(幾層堆積而成)當輸入為x時其學習到的特徵記為H(x),現在我們希望其可以學習到殘差F(x)=H(x)-x,這樣其實原始的學習特徵是F(x)+x 。當殘差為0時,此時堆積層僅僅做了恆等對映,至少網路效能不會下降,實際上殘差不會為0,這也會使得堆積層在輸入特徵基礎上學習到新的特徵,從而擁有更好的效能。一個殘差單元的公式如下:

640?wx_fmt=jpeg

後面的x前面也需要經過引數Ws變換,從而使得和前面部分的輸出形狀相同,可以進行加法運算。

在堆疊了多個殘差單元后,我們的ResNet網路結構如下圖所示:

640?wx_fmt=jpeg

ResNet程式碼實戰

我們來實現一個mnist手寫數字識別的程式。程式碼中主要使用的是tensorflow.contrib.slim中定義的函式,slim作為一種輕量級的tensorflow庫,使得模型的構建,訓練,測試都變得更加簡單。卷積層、池化層以及全聯接層都可以進行快速的定義,非常方便。這裡為了方便使用,我們直接匯入slim。

import tensorflow.contrib.slim as slim

我們主要來看一下我們的網路結構。首先定義兩個殘差結構,第一個是輸入和輸出形狀一樣的殘差結構,一個是輸入和輸出形狀不一樣的殘差結構。

下面是輸入和輸出形狀相同的殘差塊,這裡slim.conv2d函式的輸入有三個,分別是輸入資料、卷積核數量、卷積核的大小,預設的話padding為SAME,即卷積後形狀不變,由於輸入和輸出形狀相同,因此我們可以在計算outputs時直接將兩部分相加。

def res_identity(input_tensor,conv_depth,kernel_shape,layer_name):
with tf.variable_scope(layer_name):
relu = tf.nn.relu(slim.conv2d(input_tensor,conv_depth,kernel_shape))
outputs = tf.nn.relu(slim.conv2d(relu,conv_depth,kernel_shape) + input_tensor)
return outputs

下面是輸入和輸出形狀不同的殘差塊,由於輸入和輸出形狀不同,因此我們需要對輸入也進行一個卷積變化,使二者形狀相同。ResNet作者建議可以用1*1的卷積層,stride=2來進行變換:

def res_change(input_tensor,conv_depth,kernel_shape,layer_name):
with tf.variable_scope(layer_name):
relu = tf.nn.relu(slim.conv2d(input_tensor,conv_depth,kernel_shape,stride=2))
input_tensor_reshape = slim.conv2d(input_tensor,conv_depth,[1,1],stride=2)
outputs = tf.nn.relu(slim.conv2d(relu,conv_depth,kernel_shape) + input_tensor_reshape)
return outputs

最後是整個網路結構,對於x的輸入,我們先進行一次卷積和池化操作,然後接入四個殘差塊,最後接兩層全聯接層得到網路的輸出。

def inference(inputs):
x = tf.reshape(inputs,[-1,28,28,1])
conv_1 = tf.nn.relu(slim.conv2d(x,32,[3,3])) #28 * 28 * 32
pool_1 = slim.max_pool2d(conv_1,[2,2]) # 14 * 14 * 32
block_1 = res_identity(pool_1,32,[3,3],'layer_2')
block_2 = res_change(block_1,64,[3,3],'layer_3')
block_3 = res_identity(block_2,64,[3,3],'layer_4')
block_4 = res_change(block_3,32,[3,3],'layer_5')
net_flatten = slim.flatten(block_4,scope='flatten')
fc_1 = slim.fully_connected(slim.dropout(net_flatten,0.8),200,activation_fn=tf.nn.tanh,scope='fc_1')
output = slim.fully_connected(slim.dropout(fc_1,0.8),10,activation_fn=None,scope='output_layer')
return output

完整的程式碼地址在:https://github.com/princewen/tensorflow_practice/tree/master/CV/ResNet

參考方式

1、論文:https://arxiv.org/pdf/1512.03385.pdf


2、https://blog.csdn.net/kaisa158/article/details/81096588?utm_source=blogxgwz4

原文連結:https://mp.weixin.qq.com/s/AsCF4dsuS-XTF1Yr_tKKTQ

查閱更為簡潔方便的分類文章以及最新的課程、產品資訊,請移步至全新呈現的“LeadAI學院官網”:

www.leadai.org

請關注人工智慧LeadAI公眾號,檢視更多專業文章

640?wx_fmt=jpeg

大家都在看

640.png?