
Natural Scene Text Processing Paper Notes (1): Spatial Transformer Networks

Paper: Spatial Transformer Networks
In the Theano framework the STN algorithm is already wrapped as an API and can be called directly. A TensorFlow implementation is given at the end of this article.
1. Structure of the spatial transformer:
[Figure: architecture of the spatial transformer module (localisation network, grid generator, sampler)]
This is a differentiable module that applies a spatial transformation to a feature map during a single forward pass, where the transformation is conditioned on the particular input and produces a single output feature map. For multi-channel inputs, the same warping is applied to every channel. For simplicity, this section considers a single transform and a single output per transformer, but this generalises to multiple transforms, as shown in the experiments. The spatial transformer mechanism is split into three parts, as shown in the figure above. In order of computation, a localisation network first takes the input feature map and, through a number of hidden layers, outputs the parameters of the spatial transformation that should be applied to the feature map; this gives a transformation conditioned on the input. Next, the predicted transformation parameters are used to create a sampling grid, a set of points at which the input map should be sampled to produce the transformed output; this is done by the grid generator. Finally, the feature map and the sampling grid are fed to the sampler, which produces the output map sampled from the input at the grid points. The combination of these three components forms a spatial transformer.

In principle, a transformation parameter θ is learned from the feature map; θ defines a sampling grid G, and G is then used to sample the input feature map U, giving the output V. In other words, every point of V is obtained by sampling points of U.
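Concretely, for a 2x3 affine θ, each target grid point (x_t, y_t), taken in normalised [-1, 1] coordinates, is mapped to a source point (x_s, y_s) in U, and V is filled by reading U at those source points (with bilinear interpolation, as in the implementation below). A minimal NumPy sketch of just this coordinate mapping (my own illustration, not code from the paper):

import numpy as np

def affine_source_coords(theta, out_h, out_w):
    # Build the normalised target grid, append a row of ones, and map it
    # through the 2x3 affine matrix theta to get source coordinates.
    y_t, x_t = np.meshgrid(np.linspace(-1, 1, out_h),
                           np.linspace(-1, 1, out_w), indexing='ij')
    grid = np.stack([x_t.ravel(), y_t.ravel(), np.ones(out_h * out_w)])  # (3, H*W)
    src = theta @ grid                                                   # (2, H*W)
    return src[0].reshape(out_h, out_w), src[1].reshape(out_h, out_w)    # x_s, y_s

# With the identity transform, V samples U at exactly the corresponding locations;
# theta = [[0.5, 0, 0], [0, 0.5, 0]] would sample only the central half of U (a zoom).
x_s, y_s = affine_source_coords(np.array([[1., 0., 0.], [0., 1., 0.]]), 4, 4)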

2. Once such a layer is added, how many of them should be used, and how should they be connected to the other layers?
1. Each channel can have its own STN parameters, so that different spatial transforms can describe the spatial deformation of different features.
2. A single channel can also be connected to several STNs at once, to handle images that contain multiple objects (a usage sketch follows below).
Point 2 looks like the more awkward case: if an image contains multiple objects, each one may deform differently, so applying the same STN to the whole image is not really reasonable; but since the number of objects in the image is unknown, the number of STNs can only be fixed in advance.
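One way to wire up several transforms per input is the batch_transformer() helper included in the implementation at the end of this article; a hedged usage sketch (the shapes here are made up for illustration):

import tensorflow as tf

# U: a batch of 2 feature maps; thetas: 4 affine transforms (6 parameters each) per input.
U = tf.placeholder(tf.float32, [2, 100, 100, 3])
thetas = tf.placeholder(tf.float32, [2, 4, 6])
# batch_transformer repeats each input once per transform and warps every copy,
# returning a tensor of shape [2 * 4, 50, 50, 3].
outputs = batch_transformer(U, thetas, out_size=(50, 50))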

3. How the layers are connected
In the experiments an STN is inserted every few conv layers, so the spatial transform is applied directly to feature maps. Visualising the results shows that the STN does not only perform a spatial transformation but also has a cropping effect, similar to attention, which also yields some computational speed-up.
[Figure: visualisation of the transforms learned by the STN]

4. STNs in parallel
When several STNs act in parallel on the same feature map, the results show that adding more STNs still helps. The explanation is that more STNs can apply separate spatial transforms to different parts of the input and attend to different regional features.

5. The localisation network f_loc() can take any form, for example a fully connected network or a convolutional network, but it should include a final regression layer that produces the transformation parameters θ.
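For example, a minimal sketch of a convolutional f_loc (my own illustration in the TF1 style of the code below, not code from the paper): a couple of strided conv layers, a small fully connected layer, and a final regression layer whose weights are zero-initialised and whose bias encodes the identity transform, so training starts from the identity warp. Its output can be fed straight into transformer() as theta.

import tensorflow as tf

def loc_net_conv(feature_map, scope='loc_net'):
    # feature_map: [N, H, W, C] with statically known H, W, C. Returns theta of shape [N, 6].
    with tf.variable_scope(scope):
        h = tf.layers.conv2d(feature_map, 16, 5, strides=2, padding='same',
                             activation=tf.nn.relu)
        h = tf.layers.conv2d(h, 32, 5, strides=2, padding='same',
                             activation=tf.nn.relu)
        h = tf.layers.flatten(h)
        h = tf.layers.dense(h, 64, activation=tf.nn.relu)
        # Final regression layer producing the 6 affine parameters; zero weights plus
        # an identity bias make the initial spatial transform the identity mapping.
        return tf.layers.dense(
            h, 6, kernel_initializer=tf.zeros_initializer(),
            bias_initializer=tf.constant_initializer([1., 0., 0., 0., 1., 0.]))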

6. TensorFlow implementation of the STN
import tensorflow as tf
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt

def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
    print('begin-transformer')
    # tf.stack() stacks a list of tensors along a new axis;
    # tf.ones(tf.stack([...])) builds a tensor of ones with that shape.

    # tf.expand_dims() inserts a dimension of size 1 at position axis, e.g.
    # 't' is a tensor of shape [2]
    # shape(expand_dims(t, 0)) ==> [1, 2]
    # shape(expand_dims(t, 1)) ==> [2, 1]
    # shape(expand_dims(t, -1)) ==> [2, 1]
    # 't2' is a tensor of shape [2, 3, 5]
    # shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
    # shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
    # shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
    # _repeat tiles each element of x n_repeats times,
    # e.g. [0, 1] with n_repeats=3 -> [0, 0, 0, 1, 1, 1].
    def _repeat(x, n_repeats):
        with tf.variable_scope('_repeat'):
            rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
            rep = tf.cast(rep, 'int32')
            # Reshape x to an n x 1 column, multiply by the 1 x n_repeats row of ones,
            # then flatten back to 1-D.
            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
            return tf.reshape(x, [-1])
    # Bilinear interpolation: sample im at the (possibly fractional) source
    # coordinates (x, y), which are given in normalised [-1, 1] space.
    def _interpolate(im, x, y, out_size):
        with tf.variable_scope('_interpolate'):
            num_batch = tf.shape(im)[0]
            height = tf.shape(im)[1]
            width = tf.shape(im)[2]
            channels = tf.shape(im)[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = out_size[0]
            out_width = out_size[1]
            zero = tf.zeros([], dtype='int32')
            max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
            max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')

            # Rescale x and y from [-1, 1] to absolute pixel coordinates.
            x = (x + 1.0) * width_f / 2.0
            y = (y + 1.0) * height_f / 2.0

            x0 = tf.cast(tf.floor(x),'int32')
            x1 = x0 + 1
            y0 = tf.cast(tf.floor(y),'int32')
            y1 = y0 + 1

            x0 = tf.clip_by_value(x0, zero, max_x)
            x1 = tf.clip_by_value(x1, zero, max_x)
            y0 = tf.clip_by_value(y0, zero, max_y)
            y1 = tf.clip_by_value(y1, zero, max_y)
            dim2 = width
            dim1 = width*height
            base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
            base_y0 = base + y0*dim2
            base_y1 = base + y1*dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # use indices to lookup pixels in the flat image and restore
            # channels dim
            im_flat = tf.reshape(im, tf.stack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            # and finally calculate interpolated values
            x0_f = tf.cast(x0, 'float32')
            x1_f = tf.cast(x1, 'float32')
            y0_f = tf.cast(y0, 'float32')
            y1_f = tf.cast(y1, 'float32')
            wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
            wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
            wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
            wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
            output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
            return output
    # Build the normalised sampling grid of (x_t, y_t, 1) coordinates.
    def _meshgrid(height,width):
        print('begin--meshgrid')
        with tf.variable_scope('_meshgrid'):

            x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
                            tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
            print('meshgrid_x_t_ok')
            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
                            tf.ones(shape=tf.stack([1, width])))
            print('meshgrid_y_t_ok')
            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))
            print('meshgrid_flat_t_ok')
            ones = tf.ones_like(x_t_flat)
            print('meshgrid_ones_ok')
            print(x_t_flat)
            print(y_t_flat)
            print(ones)

            grid = tf.concat( [x_t_flat, y_t_flat, ones],0)
            print ('over_meshgrid')
            return grid
    # Apply the transform: map the output grid back to input coordinates and sample.
    def _transform(theta,input_dim,out_size):
        print('_transform')

        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
            width = tf.shape(input_dim)[2]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.reshape(theta, (-1, 2, 3))
            theta = tf.cast(theta, 'float32')

            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = out_size[0]
            out_width = out_size[1]
            grid = _meshgrid(out_height, out_width)
            grid = tf.expand_dims(grid, 0)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.stack([num_batch]))
            grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
            print('begin--batch--matmul')
            T_g = tf.matmul(theta, grid)
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat,
                out_size)

            output = tf.reshape(
                input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
            print('over_transformer')
            return output

    with tf.variable_scope(name):
        output = _transform(theta,U,out_size)
        return output

def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    # Apply several transforms per input: U is [num_batch, H, W, C] and thetas is
    # [num_batch, num_transforms, 6]. Each input is repeated once per transform
    # and all copies are warped by transformer() in one call.
    with tf.variable_scope(name):
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        indices = [[i] * num_transforms for i in range(num_batch)]
        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
        return transformer(input_repeated, thetas, out_size)



# Load a test image (assumed to be a 1200x1600 RGB image) and build a batch of 3 copies.
im=ndimage.imread('cat1.jpg')
im=im/255.
im=im.reshape(1,1200,1600,3)
im=im.astype('float32')
print('img-over')


out_size=(600,800)
batch=np.append(im,im,axis=0)
batch=np.append(batch,im,axis=0)
num_batch=3

x=tf.placeholder(tf.float32,[None,1200,1600,3])
print('begin---')


with tf.variable_scope('spatial_transformer_0'):
    # A minimal fully connected "localisation network": W_fc1 is all zeros, so the
    # regressed theta is simply b_fc1, i.e. a fixed transform that scales by 0.5.
    n_fc=6
    w_fc1=tf.Variable(tf.zeros([1200*1600*3,n_fc]),name='W_fc1')
    initial=np.array([[0.5,0,0],[0,0.5,0]])
    initial=initial.astype('float32')
    initial=initial.flatten()

    b_fc1=tf.Variable(initial_value=initial,name='b_fc1')

    h_fc1=tf.matmul(tf.zeros([num_batch,1200*1600*3]),w_fc1)+b_fc1
    print(x,h_fc1,out_size)
    h_trans=transformer(x,h_fc1,out_size)


sess=tf.Session()
sess.run(tf.global_variables_initializer())
y=sess.run(h_trans,feed_dict={x:batch})
plt.imshow(y[0])
plt.show()