1. 程式人生 > tf.nn.conv2d()函式詳解

tf.nn.conv2d()函式詳解

二維卷積:
inputs: 形狀為 [n, h, w, c] 的 tensor（NHWC：樣本數、高、寬、通道數）
filter: 形狀為 [h, w, c, out_size] 的 tensor（卷積核高、卷積核寬、輸入通道數、輸出通道數），其中 c 必須等於 inputs 的通道數 c

import numpy as np
import tensorflow as tf

# Demo of how tf.nn.conv2d combines input channels into output channels.
# The input and the kernel are written channels-first, then transposed into
# the layouts tf.nn.conv2d expects.

# Shape (2, 2, 3, 3), read as [n, c, h, w]: 2 samples, 2 channels, 3x3 maps.
alist = [[[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
          [[4, 4, 4], [5, 5, 5], [6, 6, 6]]],
         [[[7, 7, 7], [8, 8, 8], [9, 9, 9]],
          [[10, 10, 10], [11, 11, 11], [12, 12, 12]]]]

# The kernel is the input times 2. It has the same shape (2, 2, 3, 3), but
# its axes are read as [out_channels, in_channels, filter_h, filter_w].
kernel = (np.asarray(alist) * 2).tolist()
print(kernel)

inputs = tf.constant(alist, dtype=tf.float32)
kernel = tf.constant(kernel, dtype=tf.float32)

# [n, c, h, w] -> [n, h, w, c] (NHWC, the default data_format of conv2d).
inputs = tf.transpose(inputs, [0, 2, 3, 1])
# [out_c, in_c, fh, fw] -> [fh, fw, in_c, out_c], the filter layout that
# tf.nn.conv2d requires. A single perm [2, 3, 1, 0] is equivalent to
# applying [0, 2, 3, 1] followed by [1, 2, 3, 0].
kernel = tf.transpose(kernel, [2, 3, 1, 0])

# Stride 1 with 'SAME' padding keeps the 3x3 spatial size.
out = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
print(out.get_shape())  # (2, 3, 3, 2): [n, h, w, out_c]

# Back to [n, out_c, h, w], so each printed 3x3 block is one output channel.
result = tf.transpose(out, [0, 3, 1, 2])
if tf.executing_eagerly():
    # TensorFlow 2.x runs eagerly by default and has no tf.Session.
    print(result.numpy())
else:
    # TensorFlow 1.x graph mode: evaluate the tensor in a session.
    with tf.Session() as sess:
        print(sess.run(result))

輸出:

[[[[2, 2, 2], [4, 4, 4], [6, 6, 6]], [[8, 8, 8], [10, 10, 10], [12, 12, 12]]], [[[14, 14, 14], [16, 16, 16], [18, 18, 18]], [[20, 20, 20], [22, 22, 22], [24, 24, 24]]]]
(2, 3, 3, 2)
[[[[ 232.  348.  232.]
   [ 364.  546.  364.]
   [ 232.  348.  232.]]

  [[ 520.  780.  520.]
   [ 868. 1302.  868.]
   [ 616.  924.  616.]]]


 [[[ 616.  924.  616.]
   [ 868. 1302.  868.]
   [ 520.  780.  520.]]

  [[1480. 2220. 1480.]
   [2236. 3354. 2236.]
   [1480. 2220. 1480.]]]]

詳解:
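inputs=[[c11,c12],[c21,c22]]，其中 cij 為第 i 個樣本的第 j 個輸入通道，每個都是 h×w 的矩陣
filter=[[f11,f12],[f21,f22]]，其中 fij 為連接輸入通道 j 與輸出通道 i 的卷積核，每個都是核大小 h×w 的矩陣（即程式中 alist*2 在轉置前的 [out, in, h, w] 排列）
out=[[o11,o12],[o21,o22]]，其中 oij 為第 i 個樣本的第 j 個輸出通道，每個都是 h×w 的矩陣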
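o11=conv(c11,f11)+conv(c12,f12)
o12=conv(c11,f21)+conv(c12,f22)
o21=conv(c21,f11)+conv(c22,f12)
o22=conv(c21,f21)+conv(c22,f22)
一般式：oij = Σ_k conv(cik, fjk)，即第 j 個輸出通道由每個輸入通道 k 與對應卷積核 fjk 分別做二維卷積後相加而得。
例如 o11 的中心值 = sum(c11*f11) + sum(c12*f12) = 84 + 462 = 546，與上面的輸出一致。
注意：tf.nn.conv2d 實際計算的是互相關（卷積核不翻轉）；這裡 stride 為 1 且 padding='SAME'，所以輸出仍為 3×3。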