1. 程式人生 > >tf.nn.conv2d_transpose 例項 及 解析

tf.nn.conv2d_transpose 例項 及 解析

看原始碼可見conv2d_transpose實際就是計算conv2d_backprop_input

// Consider a case where the input is 3x3 and the filter is 2x1:
//
// INPUT = [ A  B  C ]
//         [ D  E  F ]
//         [ G  H  I ]
//
// where each "A", "B", etc is batch x in_depth
//
// FILTER = [ X  Y ]
//
// where both "X" and "Y" are in_depth x out_depth
//
// With VALID padding, the output is 3x2:
// // OUTPUT = [ a b ] // [ c d ] // [ e f ] // // where each "a", "b", etc is batch x out_depth // // So we have: // // a = A * X + B * Y // b = B * X + C * Y // c = D * X + E * Y // d = E * X + F * Y // e = G * X + H * Y // f = H * X + I * Y // // So when we have backprops for the outputs (we denote them by
// a', b', ... ): // // The backprops for the input are: // 因為A'和a'都是backprops,所以這裡X^t理解為W // A' = a' * X^t // B' = a' * Y^t + b' * X^t // C' = b' * Y^t // ... // // This is essentially computing a 2d conv of // // INPUT2 = [ 0 a' b' 0 ] // [ 0 c' d' 0 ] // [ 0 e' f' 0 ] // and // // FILTER2 = [ Y^t X^t ]#注意這裡是YX不是XY
//INPUT2在下面的例子裡就是 [[0 1 -1 0] [0 2 2 0] [0 1 2 0]] //FILTER2在下面的例子裡就是 [[-1,1]] //而有了INPUT2和FILTER2這時居然就可以計算A'到I'了! //在下面的例子里根據 A' = a' * X^t 計算A'居然正好是A,可以同理到I居然! //也就是INPUT2和FILTER2的卷積就是INPUT [[ 1. -2. 1.] [ 2. 0. -2.] [ 1. 1. -2.]]
import numpy as np
import tensorflow as tf


# [batch, height, width, depth]
x_image = tf.placeholder(tf.float32,shape=[3,2])
x = tf.reshape(x_image,[1,3,2,1])

#Filter: W  [kernel_height, kernel_width, output_depth, input_depth]
W_cpu = np.array([[1,-1]],dtype=np.float32)
W = tf.Variable(W_cpu)
W = tf.reshape(W, [1,2,1,1])

strides=[1, 1, 1, 1]
padding='VALID'

y = tf.nn.conv2d_transpose(x, W, [1,3,3,1],strides, padding)

x_data = np.array([[1,-1],[2,2],[1,2]],dtype=np.float32)
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)

    x = (sess.run(x, feed_dict={x_image: x_data}))
    W = (sess.run(W, feed_dict={x_image: x_data}))
    y = (sess.run(y, feed_dict={x_image: x_data}))

    print "The shape of x:\t", x.shape, ",\t and the x.reshape(3,2) is :"
    print x.reshape(3,2)
    print ""

    print "The shape of x:\t", W.shape, ",\t and the W.reshape(1,2) is :"
    print W.reshape(1,2)
    print ""

    print "The shape of y:\t", y.shape, ",\t and the y.reshape(3,3) is :"
    print y.reshape(3,3)
    print ""

最後一個print是
[[ 1. -2. 1.]
[ 2. 0. -2.]
[ 1. 1. -2.]]
也就是上面解釋的最後那塊
也就是我們知道了自己手算檢驗這個函式計算結果方法