TensorFlow函式之tf.nn.conv2d()(附程式碼詳解)
tf.nn.conv2d是TensorFlow裡面實現卷積的函式,是搭建卷積神經網路比較核心的一個方法。
函式格式:
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu = Noen, name = None)
引數說明:
第一個引數input:指需要做卷積的輸入影象,它要求是一個4維的Tensor,型別為float32和float64之一,shape為[batch, in_height, in_width, in_channels]:
- batch:訓練時一個batch的圖片數量
- in_height:輸入影象的高度
- in_width:輸入影象的寬度
- in_channels:輸入影象的通道數,灰度影象則為1,彩色影象則為3
第二個引數filter:CNN卷積網路中的卷積核,要求是一個Tensor,型別和input型別相同,shape為[filter_height, filter_width, in_channels, out_channels]:
- filter_height:卷積核的高度
- filter_width:卷積核的寬度
- in_channels:影象的通道數,input的in_channels相同
- out_channels:
卷積核的個數第三個引數strides:不同維度上的步長,是一個長度為4的一維向量,[ 1, strides, strides, 1],第一維和最後一維的數字要求必須是1。因為卷積層的步長只對矩陣的長和寬有效。
第四個引數padding:string型別,表示卷積的形式,是否考慮邊界,值為“SAME”和“VALID”,"SAME"是考慮邊界,不足的時候用填充周圍,"VALID"則不考慮邊界。
第五個引數use_cudnn_on_gpu: bool型別,是否使用cudnn加速,預設為true。
該函式返回的就是我們常說的feature map,shape仍然是[batch, height, width, channels]
下邊通過例子來說明tf.nn.conv2d()函式的用法:
case1:輸入是1張 3*3 大小的圖片,影象通道數是5,卷積核是 1*1 大小,數量是1 ,步長是[1,1,1,1],最後得到一個 3*3 的feature map。1張圖最後輸出就是一個 shape為[1,3,3,1] 的tensor。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tensorflow as tf
Input = tf.Variable(tf.random_normal([1, 3, 3, 5]))
Filter = tf.Variable(tf.random_normal([1, 1, 5, 1]))
conv1 = tf.nn.conv2d(Input, Filter, strides=[1, 1, 1, 1], padding='VALID')
with tf.Session() as sess:
# 初始化變數
op_init = tf.global_variables_initializer()
sess.run(op_init)
print(sess.run(conv1))
執行結果:
case2:輸入是1張 3*3 大小的圖片,影象通道數是5,卷積核是 2*2大小,數量是1 ,步長是[1,1,1,1],padding 設定為“VALID”,最後得到一個 2*2的feature map。1張圖最後輸出就是一個 shape為[1,2,2,1] 的tensor。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tensorflow as tf
Input = tf.Variable(tf.random_normal([1, 3, 3, 5]))
Filter = tf.Variable(tf.random_normal([2, 2, 5, 1]))
conv2 = tf.nn.conv2d(Input, Filter, strides=[1, 1, 1, 1], padding='VALID')
with tf.Session() as sess:
# 初始化變數
op_init = tf.global_variables_initializer()
sess.run(op_init)
print(sess.run(conv2))
執行結果:
case3:將case2例子的padding值改為“SAME”,即考慮邊界。(輸入是1張 3*3 大小的圖片,影象通道數是5,卷積核是 2*2大小,數量是1 ,步長是[1,1,1,1],padding 設定為“SAME”,最後得到一個 3*3的feature map。1張圖最後輸出就是一個 shape為[1,3,3,1] 的tensor。)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tensorflow as tf
Input = tf.Variable(tf.random_normal([1, 3, 3, 5]))
Filter = tf.Variable(tf.random_normal([2, 2, 5, 1]))
conv3= tf.nn.conv2d(Input, Filter, strides=[1, 1, 1, 1], padding='SAME')
with tf.Session() as sess:
# 初始化變數
op_init = tf.global_variables_initializer()
sess.run(op_init)
print(sess.run(conv3))
執行結果為:
case4:輸入是2張 3*3 大小的圖片,影象通道數是5,卷積核是 2*2大小,數量是1,步長是[1,1,1,1],padding 設定為“SAME”,最後得到2個 3*3的feature map。1張圖最後輸出就是一個 shape為[2,3,3,1] 的tensor。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tensorflow as tf
Input = tf.Variable(tf.random_normal([2, 3, 3, 5]))
Filter = tf.Variable(tf.random_normal([2, 2, 5, 1]))
conv4 = tf.nn.conv2d(Input, Filter, strides=[1, 1, 1, 1], padding='SAME')
with tf.Session() as sess:
# 初始化變數
op_init = tf.global_variables_initializer()
sess.run(op_init)
print(sess.run(conv4))
執行結果:
case4:輸入是4張 3*3 大小的圖片,影象通道數是5,卷積核是 2*2大小,數量是4,步長是[1,1,1,1],padding 設定為“SAME”,最後每張圖片得到4個 3*3的feature map。1張圖最後輸出就是一個 shape為[4,3,3,4] 的tensor。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tensorflow as tf
Input = tf.Variable(tf.random_normal([4, 3, 3, 5]))
Filter = tf.Variable(tf.random_normal([2, 2, 5, 4]))
conv5 = tf.nn.conv2d(Input, Filter, strides=[1, 1, 1, 1], padding='SAME')
with tf.Session() as sess:
# 初始化變數
op_init = tf.global_variables_initializer()
sess.run(op_init)
print(sess.run(conv5))
執行結果為:
[[[[ 4.1313605 -4.5319715 4.3040133 -0.13911057]
[-10.3389225 -2.2463617 -3.033617 2.0877244 ]
[ -2.3154557 5.4515543 -4.647153 -3.0869713 ]]
[[ 3.0420232 -0.87613493 4.6381464 0.90558195]
[ 3.7932742 -2.4520369 -0.7195463 4.9921722 ]
[ 1.5384624 0.11533099 -2.408429 5.3733883 ]]
[[ 0.45401835 -2.7483764 0.07065094 0.443908 ]
[ -6.8117185 2.2884533 -5.3677235 1.5834118 ]
[ -1.2048031 1.848783 -1.6733127 -1.7782905 ]]]
[[[ 0.5970616 0.77205956 -3.116805 -2.0577238 ]
[ -1.9527029 0.34587446 -5.7365794 -7.39325 ]
[ 2.9361765 -3.504585 4.057106 1.7304868 ]]
[[ 7.262891 -2.492215 4.7126684 1.7249267 ]
[ -7.4239445 2.9972248 4.2400084 0.9729483 ]
[ 1.9529393 3.4738922 0.24985534 -2.922786 ]]
[[ -0.3828507 0.49657637 0.47466695 0.8126482 ]
[ -0.3671472 -0.67494106 0.46129555 -0.9638461 ]
[ 1.6664319 0.885748 0.31974202 -1.9972321 ]]]
[[[ -1.7906806 -1.0376648 1.7166338 -1.0403266 ]
[ 2.5069232 4.0962963 -1.6884253 -0.23492575]
[ -0.3881849 -2.4799151 3.8491216 -2.5564618 ]]
[[ -3.5605621 1.6819414 -4.8645535 -1.9251393 ]
[ -1.8077193 -5.6664057 4.750779 1.2238139 ]
[ 2.346087 -6.2254734 5.1787786 4.3055882 ]]
[[ 0.21012396 -2.145918 1.358845 -0.25860584]
[ -2.3559752 4.263964 -0.51586103 -6.9163604 ]
[ 5.504827 4.0707703 0.11547554 -7.818963 ]]]
[[[ 3.716207 0.63655686 -1.2434999 -4.472955 ]
[ 2.067041 2.0510454 4.2357826 -4.159449 ]
[ 0.63638914 1.9863756 0.42491168 0.13413942]]
[[ 2.5624473 5.041215 3.689196 -1.1119944 ]
[ 4.470556 6.4554853 -7.154124 3.396954 ]
[ -4.9033055 -4.636659 3.7072423 -0.6310719 ]]
[[ -0.92771626 5.868825 -3.1153893 -3.9012384 ]
[ 3.4494157 -2.7036839 5.6135087 3.6358144 ]
[ 1.0283424 -4.246024 0.7325883 0.4899721 ]]]]