
Polygon Instance Segmentation: Paper Reading and Implementation

Model overview:

    For instance segmentation there are traditionally successful networks such as FCN, which solve the problem with an encoder-decoder "regression" over pixel-level labels. The same regression-style treatment of encoder-decoder image problems also appeared earlier, in the keypoint-detection part of the post on Ordinal Depth Supervision for 3D Human Pose Estimation (paper reading and implementation). Judging by how well the regression loss fits the problem, the latter is the more "agreeable" case, because there the loss and the label are dual to each other: the loss is spherical (L2) and the quantity being estimated is also a "spherical" object, the keypoint center. The problem FCN and other pixel-regression methods face is that the loss is spherical while the target can be a polygon, and a spherical loss tends to blur the estimates of the polygon's vertices, so the results of these earlier approaches typically look like the figure below:

     The basic idea of the two papers cited above is to replace regression with classification: estimate the vertices of the instance boundary directly and "approximate" the whole boundary with the vertex polygon. Below, a Polygon Detection problem is built on top of the encoder part of Polygon-RNN++. (The decoder is also introduced briefly, but my implementation of it performed rather poorly on the dataset used here, presumably because of problems in my own implementation, so the current implementation swaps in a different network structure; the reinforcement-learning part of the original paper is skipped entirely.)
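To make the contrast concrete, here is a minimal TF 1.x sketch (my own illustration with assumed placeholder shapes, not code from either paper) of the two loss styles: a per-pixel "spherical" L2 loss over a dense mask versus a softmax classification over a D x D vertex grid plus a stop class:

import tensorflow as tf

D = 28

# per-pixel regression: "spherical" (L2) loss over a dense mask
pred_mask = tf.placeholder(tf.float32, [None, D, D], name="pred_mask")
true_mask = tf.placeholder(tf.float32, [None, D, D], name="true_mask")
l2_loss = tf.reduce_mean(tf.square(pred_mask - true_mask))

# vertex classification: softmax over D*D grid cells plus one stop class
vertex_logits = tf.placeholder(tf.float32, [None, D * D + 1], name="vertex_logits")
vertex_target = tf.placeholder(tf.int32, [None], name="vertex_target")
ce_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=vertex_target,
                                                    logits=vertex_logits))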

The problem representation is shown in the figure below:

The goal is to estimate a Polygon inside a bbox (bounding box).

The network structure is as follows:

The CNN Encoder structure is as follows:

       The basic structure: the image is fed through the CNN Encoder (a fusion of several residual blocks and conv-layer features) to produce the extracted feature map ([28x28x128]), which is then decoded into the positions of the Polygon vertices. The decoding structure is RNN-like; each step emits a (28*28+1)-dimensional one-hot vector whose first 28*28 dimensions correspond to vertex positions at the low pixel resolution and whose last dimension is the stop signal. Training with labels at this resolution through a softmax classification satisfies the requirements above, and this overall structure is essentially the original Polygon-RNN architecture (the second linked paper).
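For reference, a small helper of my own (not from the paper or from the code below) showing how a vertex in the 224x224 crop maps into that 28*28+1 "vocabulary"; the data-processing script later uses the same row-major convention (index = row * 28 + col):

def vertex_to_token(x, y, img_size=224, grid=28):
    # map a vertex (x, y) in the resized crop to a grid-cell class index in [0, grid*grid)
    col = min(int(x / img_size * grid), grid - 1)
    row = min(int(y / img_size * grid), grid - 1)
    return row * grid + col  # indices >= grid*grid are reserved for the stop signal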

     The decoder is the harder part. It can be compared to a seq2seq translation model in NLP, except that here the image features are "translated" into the image boundary; the encoder output above then acts as the context vector. On the decoder side we need the counterpart of the <go> token in seq2seq: for Polygon estimation the first decoded vertex also matters a great deal (because of beam search). The authors use another encoding structure (mentioned above) to estimate the first Polygon vertex and then unroll the decoder (it estimates boundary edges first, then vertices).
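Polygon-RNN++ trains a dedicated first-vertex network for this. As a rough stand-in (my own simplification, not the paper's exact head), one could collapse the [batch, 28, 28, 128] skip features into a softmax over the 28x28 grid and take its argmax as the <go> vertex:

import tensorflow as tf

def first_vertex_head(skip_features):
    # skip_features: [batch, 28, 28, 128] encoder output
    logits_map = tf.layers.conv2d(skip_features, filters=1, kernel_size=(1, 1),
                                  name="first_vertex_logits")   # [batch, 28, 28, 1]
    logits = tf.reshape(logits_map, [-1, 28 * 28])              # [batch, 784]
    probs = tf.nn.softmax(logits, axis=-1)                      # distribution over grid cells
    first_vertex = tf.argmax(probs, axis=-1)                    # row-major cell index
    return probs, first_vertex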

Although I ran into some problems debugging the decoder, here is a brief description:

First, a screenshot:

   The above shows the decoder details. The basic structure uses a ConvLstm (an Lstm whose weight matrices are replaced by convolutions so it can consume image-shaped inputs) as the decoding cell; at each step, features about the previous step (the previous two vertices) are extracted and gated (alpha_t) before the next iteration. The following TensorFlow snippet tries to describe what this part does:

# This snippet is excerpted from inside the model class (hence the self.* references).
# TF 1.x; the recurrent cells come from tf.contrib.rnn.
import tensorflow as tf
from tensorflow.contrib.rnn import Conv2DLSTMCell, MultiRNNCell, DropoutWrapper

lstm_cell = MultiRNNCell([
    Conv2DLSTMCell(
        input_shape = [28, 28, 128],
        output_channels = 64,
        kernel_shape = [3, 3], name="conv2d_lstm_1"),
    Conv2DLSTMCell(
        input_shape = [28, 28, 128],
        output_channels = 16,
        kernel_shape = [3, 3], name="conv2d_lstm_2")
])

lstm_cell = DropoutWrapper(cell=lstm_cell, input_keep_prob=self.keep_prob,
                           output_keep_prob=self.keep_prob,
                           state_keep_prob=self.keep_prob,)
lstm_cell_zero_state = lstm_cell.zero_state(batch_size=self.batch_size, dtype=tf.float32)

state = lstm_cell_zero_state
inputs = batch_skip_features

batch_first_vertex = tf.transpose(tf.reshape(tf.one_hot(self.rnn_step_target[:, 0], depth=28 * 28), [self.batch_size, 28, 28, 1]),
                                  [0, 2, 1, 3])

self.batch_first_vertex = batch_first_vertex

onehot_output_list = [batch_first_vertex]
softmax_output_list = []
with tf.variable_scope("conv2d_call_procedure", reuse=tf.AUTO_REUSE):
    for i in range(self.max_step_num):
        output, new_state = lstm_cell(inputs, state)
        # [batch, DxD+1]
        onehot_output, softmax_output = self.y_onehot_predict_layer(output)
        onehot_output_list.append(tf.reshape(onehot_output[:, :-1], [-1, 28, 28, 1]))
        softmax_output_list.append(softmax_output)

        cell_1_state_h, cell_2_state_h = map(lambda state_tuple: state_tuple.h, state)

        f1 = tf.layers.dense(inputs=cell_1_state_h, units=128, name="f1_{}".format(i), reuse=tf.AUTO_REUSE)
        f2 = tf.layers.dense(inputs=cell_2_state_h, units=128, name="f2_{}".format(i), reuse=tf.AUTO_REUSE)

        fattn = tf.add_n([batch_skip_features, f1, f2])

        fattn_flatten = tf.reshape(fattn, [-1, self.D * self.D * int(batch_skip_features.get_shape()[-1])])
        fattn_flatten_softmax = tf.nn.softmax(fattn_flatten, axis=-1)
        fattn_softmax = tf.reshape(fattn_flatten_softmax, [-1, self.D, self.D, int(batch_skip_features.get_shape()[-1])])

        # [batch_num, 28, 28, 128]
        Ft = batch_skip_features * fattn_softmax

        Ft_with_previous_y = tf.concat(onehot_output_list[-2:] + [Ft], axis=-1)

        inputs = tf.layers.dense(inputs=Ft_with_previous_y, units=128, name = "fuse_Ft_previous_y_layer_{}".format(i), reuse=tf.AUTO_REUSE)
        state = new_state

      This is also the decoder structure of my first reproduction attempt; it has some convergence problems and serves here only as an aid for explaining the decoder of the original paper. Anyone interested is welcome to discuss it.

       Looking at the decoder and the structure of the problem, we can abstract it as computing a "convex hull" of the instance features. Pointer Net is a good solution for this kind of problem (for a brief introduction and an application to text, see the earlier post on Pointer Generator summarization and its first attempt at a policy-gradient RL version); the two problems are essentially the same and share the same basic structure (in how the attention is constructed).
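For completeness, here is a minimal pointer-style attention step (assumed shapes and layer names, not the exact Pointer Net formulation): score every encoder position against the current decoder state and "point" at the argmax as the next vertex.

import tensorflow as tf

def pointer_step(encoder_feats, decoder_state, hidden=128):
    # encoder_feats: [batch, positions, feat_dim], decoder_state: [batch, state_dim]
    e = tf.layers.dense(encoder_feats, hidden, name="ptr_enc", reuse=tf.AUTO_REUSE)
    d = tf.layers.dense(decoder_state, hidden, name="ptr_dec", reuse=tf.AUTO_REUSE)
    scores = tf.layers.dense(tf.tanh(e + tf.expand_dims(d, 1)), 1,
                             name="ptr_v", reuse=tf.AUTO_REUSE)   # [batch, positions, 1]
    attn = tf.nn.softmax(tf.squeeze(scores, axis=-1), axis=-1)    # attention over positions
    next_vertex = tf.argmax(attn, axis=-1)                        # the pointed-at position
    return attn, next_vertex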

Below is a brief introduction to the self-implemented decoder structure:

         A substitute for Pointer Net on the convex-hull problem can be a plain Lstm, which serves as a baseline for this convex-optimization-style task. As an extension of that baseline, the Lstm can be replaced directly by a Memory Network style cell (here the RelationalMemory network from the earlier post on Relational recurrent neural networks paper/code reading and implementation examples); see the implementation for the concrete form.

      Since the dataset used in the original paper is hard to obtain, the person category of the COCO dataset (harder than the original dataset because of scene complexity and unstable resolution) is used below for a simple test of the effect.

To use the COCO dataset in Python, the corresponding API has to be installed:
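(The usual route is installing pycocotools via pip; on Windows, which is what the paths below assume, a prebuilt fork such as pycocotools-windows may be easier to install. This is a general note rather than something specific to this post.)

pip install cython
pip install pycocotools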

        The Polygon annotations COCO provides ( ann["iscrowd"] == 0 ) have one problem: after taking the convex hull, the vertices are unevenly distributed (sometimes there are very few points), so the convex hull of the estimated Polygon ends up with very few edges (for example just a simple rectangle, with a large error: the hull sits far from the person's boundary). The convex-hull Polygon of the annotation therefore has to be interpolated. One issue with interpolation is that it only works on functions (it fails when one independent-variable value maps to several dependent values), so the 2D curve also has to be split so that within each piece the coordinate is monotone in the step index (a monotone split of the curve, with the step index as the independent variable). All of this is implemented in the data-processing code below; a toy sketch of the monotone-split idea comes first, followed by the data processing and export script:
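As mentioned, a toy numpy sketch (my own, not part of the export script) of the monotone-split idea: cut the coordinate sequence wherever the sign of its step-to-step difference flips, so each piece is monotone and can safely be handed to interp1d.

import numpy as np

xs = np.array([0., 1., 3., 2., 1., 2., 4.])
signs = np.sign(np.diff(xs))                   # +1 while increasing, -1 while decreasing
turns = np.where(np.diff(signs) != 0)[0] + 1   # positions where the direction flips
pieces = np.split(xs, turns + 1)               # [0, 1, 3], [2, 1], [2, 4] -- each piece is monotone
print(pieces)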

from PIL import Image, ImageDraw
from pycocotools.coco import COCO
from glob import glob
from collections import defaultdict
import numpy as np
from matplotlib.patches import Polygon
from uuid import uuid1

from scipy.interpolate import interp1d
import cv2
from copy import deepcopy

def cubic_interp1d(x, y, num = 3):
    f2 = interp1d(x, y, kind='cubic')
    xnew = np.linspace(x[0], x[-1], num=num, endpoint=True)
    ynew = f2(xnew)

    return (xnew, ynew)

def split_seg_array(seg_array, interp = False):
    req_list = []
    y = seg_array[:, 1]

    while True:
        split_idx = 0
        if len(y) <= 1:
            seg_array = seg_array[split_idx:,...]
            if interp and len(seg_array) > 4:
                will_append = seg_array
                x, y = will_append[:, 0], will_append[:, 1]
                x, y = cubic_interp1d(x, y)
                seg_array = np.transpose(np.array([x, y]), [1, 0])
            break

        diff_array = y[:-1] - y[1:]
        diff_sign = np.sign(diff_array)
        if diff_sign[0] == 1:
            split_idx = np.argmin(diff_sign)
        else:
            split_idx = np.argmax(diff_sign)
        if not split_idx:
            seg_array = seg_array[split_idx:,...]
            if interp and len(seg_array) > 4:
                will_append = seg_array
                x, y = will_append[:, 0], will_append[:, 1]
                x, y = cubic_interp1d(x, y)
                seg_array = np.transpose(np.array([x, y]), [1, 0])
            break

        will_append = seg_array[:split_idx + 1,...]
        if interp and len(seg_array) > 4:
            x, y = will_append[:, 0], will_append[:, 1]
            x, y = cubic_interp1d(x, y)
            req_list.append(np.transpose(np.array([x, y]), [1, 0]))
        else:
            req_list.append(will_append)
        seg_array = seg_array[split_idx + 1:,...]
        y = seg_array[:, 1]

    req_list.append(seg_array)
    return req_list

def filter_seg_array(seg_array):
    x = seg_array[:, 0]
    y = seg_array[:, 1]
    new_x = []
    new_y = []
    for idx ,ele in enumerate(x):
        if new_x:
            if ele != new_x[-1]:
                new_x.append(ele)
                new_y.append(y[idx])
        else:
            new_x.append(ele)
            new_y.append(y[idx])
    return np.transpose(np.array([new_x, new_y]), [1, 0])

def gen_crop_img(img_array, bbox_seg_list, show = False, crop_ratio = 0.0, max_step_num = 40, flip = False,
                 gen_convex_hull = True, interp = True):
    img = Image.fromarray(img_array.astype(np.uint8))

    bbox, seg_array = bbox_seg_list[0]

    if gen_convex_hull:
        # filter vertex by convex hull
        seg_array = cv2.convexHull(points=seg_array.astype(np.int32), returnPoints=True)
        seg_array = seg_array.reshape([-1, 2])

    if interp:
        seg_array = filter_seg_array(seg_array)
        xmax, xmin, ymax, ymin = seg_array[:, 0].max(), seg_array[:, 0].min(), seg_array[:, 1].max(), seg_array[:, 1].min()

        timestep_x = np.transpose(np.array([np.array(list(range(len(seg_array)))) ,seg_array[:, 0]]), [1, 0])
        split_timestep_x = split_seg_array(timestep_x)

        req_list = []
        for timestep_x in split_timestep_x:
            timestep, x = timestep_x[:, 0], timestep_x[:, 1]
            y = seg_array[timestep.astype(np.int32), 1]
            if len(y) > 4:
                x, y = cubic_interp1d(x, y, num=len(x) * 2)
            inner_seg_array = np.transpose(np.array([x, y]), [1, 0])
            req_list.append(inner_seg_array)
        seg_array = np.concatenate(req_list, axis=0)
        seg_array[:, 0] = np.clip(seg_array[:, 0], xmin, xmax)
        seg_array[:, 1] = np.clip(seg_array[:, 1], ymin, ymax)

    # sample max_step_num elements in order
    req_indeces = sorted(set(np.linspace(0, len(seg_array) - 1, max_step_num).astype(np.int32).tolist()))[:-1]
    seg_array = seg_array[req_indeces, ...]

    # filter by seg_array length
    if len(seg_array) < max_step_num * 0.6:
        return None

    x, y, w, h = bbox
    if w * h / (img.size[0] * img.size[1]) < crop_ratio:
        return None

    x0, y0, x1, y1 = x - w, y - h, x + w, y + h

    seg_array[:,0] = (seg_array[:,0] - max(0,x0)) / (min(int(x1), img.size[0]) - max(0,int(x0)) + 1) * 224
    if flip:
        seg_array[:,0] = 224 - seg_array[:,0]
        assert np.all(seg_array[:,0] >= 0) and np.all(seg_array[:, 0] <= 224)

    seg_array[:,1] = (seg_array[:,1] - max(0,y0)) / (min(int(y1), img.size[1]) - max(0,int(y0)) + 1) * 224

    if not np.any(seg_array):
        return None

    # filter ele whose seg_array out of square
    filter_val = 224 * 0.8
    if (seg_array[:,0].max() - seg_array[:,0].min() > filter_val) or (seg_array[:,1].max() - seg_array[:,1].min() > filter_val):
        return None

    img_array = np.array(img)
    img = Image.fromarray(img_array[max(0,int(y0)):min(int(y1), img.size[1]),max(0,int(x0)):min(int(x1), img.size[0]),...].astype(np.uint8))
    img = img.resize((224, 224))
    img_array = np.array(img)

    if show:
        draw = ImageDraw.Draw(img)
        for idx ,xy in enumerate(seg_array):
            if idx % 3 == 0:
                draw.text(xy = xy.tolist(), text=str(idx))
        img.save(r"E:\Temp\test\{}.jpg".format(uuid1()))

    return (img_array, seg_array)

# use category as target category
def data_and_annots_gen(dataType = "train2017", req_category = "person", transform_to_bgr2lab = False,
                        only_one = False):
    assert dataType in ["train2017", "val2017"]
    dataDir = r"E:\Temp\annotations_trainval2017"
    annFile='{}/annotations/instances_{}.json'.format(dataDir,dataType)

    coco=COCO(annFile)
    cats = coco.loadCats(coco.getCatIds())
    nms = [cat["name"] for cat in cats]
    assert req_category in nms

    catIds = coco.getCatIds(catNms=[req_category])
    imgIds = coco.getImgIds(catIds=catIds)
    annIds = coco.getAnnIds(imgIds=imgIds, catIds = catIds, iscrowd=None)

    anns = coco.loadAnns(annIds)

    anns_map = map(lambda ann_o: (ann_o["image_id"], [ann_o]),filter(lambda ann: ann["category_id"] in catIds, anns))
    img_annos_dict = defaultdict(list)
    for imgId, ann_list in anns_map:
        img_annos_dict[imgId] += ann_list

    head_dir = r"E:\Temp\{}".format(dataType)
    id_path_dict = dict(map(lambda path: (int(path.split("\\")[-1].replace(".jpg", "")) ,path), glob(head_dir + "\\" + "*")))

    def single_ann_to_bbox_seg_dict(ann):
        assert ann["category_id"] in catIds
        ann_segmentation = ann["segmentation"]
        ann_bbox = ann["bbox"]
        seg = ann_segmentation[0]
        poly = Polygon(np.array(seg).reshape((int(len(seg)/2), 2)))
        xy_array = poly.get_xy()

        return (ann_bbox, xy_array)

    def expand_ann_list(ann_list):
        req_ann_list = []
        for ann in ann_list:
            if ann["iscrowd"] == 0:
                for ann_seg_idx in range(len(ann["segmentation"])):
                    temp_ann = deepcopy(ann)
                    temp_ann["segmentation"] = ann["segmentation"][ann_seg_idx: ann_seg_idx + 1]
                    req_ann_list.append(temp_ann)
        return req_ann_list

    for idx ,(imgId, ann_list) in enumerate(img_annos_dict.items()):
        flip = True if (idx % 2 == 0) else False

        # bbox seg tuple list
        if only_one:
            bbox_seg_list = list(map(single_ann_to_bbox_seg_dict , filter(lambda ann: ann["iscrowd"] == 0 and len(ann["segmentation"]) == 1, ann_list)))
        else:
            bbox_seg_list = list(map(single_ann_to_bbox_seg_dict , expand_ann_list(ann_list)))

        if not bbox_seg_list:
            continue

        img_array = np.asarray(Image.open(id_path_dict[imgId]))

        if img_array.shape[-1] != 3:
            continue

        if transform_to_bgr2lab:
            img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2LAB)

        conclusion = gen_crop_img(img_array, bbox_seg_list, flip = flip)
        if conclusion is None:
            continue

        img_array, seg_array = conclusion
        vertex_mask, target, step = resize_224_to_28(seg_array)

        if flip:
            img_array = cv2.flip(np.array(img_array, dtype=np.uint8), 1)

        # imgId, [224, 224, 3] xy [28, 28]
        yield (imgId, img_array, seg_array, vertex_mask, target, step)

    yield None
    return

def resize_224_to_28(seg_array):
    seg_array = (seg_array / (224 + 1) * 28).astype(np.int32)

    req = np.zeros(shape=[28, 28])

    # [28 * 28 + 1]
    mask = [0] * (28 * 28 + 1)
    idx = 0
    for idx, (x, y) in enumerate(seg_array):
        req[y, x] = 1
        mask[idx] = y * 28 + x
    mask[idx + 1] = 28 * 28 + 1

    # scalar
    step = len(seg_array) + 1
    return req, mask, step


def batch_generator(batch_size = 4 ,dataType = "train2017", req_category = "person",
                    only_one = False):
    assert dataType in ["train2017", "val2017"]
    data_gen = data_and_annots_gen(dataType, req_category=req_category)

    batch_img_array = np.zeros(shape=[batch_size, 224, 224, 3])
    batch_vertex_mask = np.zeros(shape=[batch_size, 28, 28])
    batch_step_mask = np.zeros(shape=[batch_size])
    batch_rnn_step_target = np.zeros(shape=[batch_size, 28 * 28 + 1])

    start_idx = 0
    while True:
        gen_data = data_gen.__next__()
        if gen_data is None:
            break

        imgId, img_array, seg_array, vertex_mask, target, step = gen_data
        batch_img_array[start_idx] = img_array
        batch_vertex_mask[start_idx] = vertex_mask
        batch_step_mask[start_idx] = step
        batch_rnn_step_target[start_idx] = np.array(target)

        start_idx += 1

        if start_idx == batch_size:
            # keep yielding this single batch forever (useful for overfitting tests)
            if only_one:
                while True:
                    yield (batch_img_array, batch_vertex_mask, batch_step_mask, batch_rnn_step_target)
            else:
                yield (batch_img_array, batch_vertex_mask, batch_step_mask, batch_rnn_step_target)

            start_idx = 0
            batch_img_array = np.zeros(shape=[batch_size, 224, 224, 3])
            batch_vertex_mask = np.zeros(shape=[batch_size, 28, 28])
            batch_step_mask = np.zeros(shape=[batch_size])
            batch_rnn_step_target = np.zeros(shape=[batch_size, 28 * 28 + 1])

    yield None
    return

        cubic_interp1d is the cubic interpolation function and split_seg_array is the monotone curve-splitting function. The various return None conditions in gen_crop_img make sure the image resolution is high enough and there are enough annotation points (the latter mainly because RNN-like structures suffer when sequence lengths differ wildly across samples, which is really a form of sample imbalance whose impact falls mostly on the sequence-stop "signal"; similar problems are common in NLP, e.g. in entity recognition, event extraction and summarization).
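A minimal way (my own sanity check, matching the shapes produced above) to consume the generator and inspect one batch before training:

if __name__ == "__main__":
    gen = batch_generator(batch_size=4, dataType="val2017", req_category="person")
    batch = next(gen)
    if batch is not None:
        batch_img, batch_vertex_mask, batch_step_mask, batch_rnn_step_target = batch
        print(batch_img.shape)              # (4, 224, 224, 3)
        print(batch_vertex_mask.shape)      # (4, 28, 28)
        print(batch_step_mask)              # per-sample polygon step counts
        print(batch_rnn_step_target.shape)  # (4, 28 * 28 + 1)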

       Below are the network-related functions, starting with some utility functions:

Residual block part:

import tensorflow as tf

def conv2d(inputs, filters, kernel_size, strides = (2, 2), name = None,
           add_max_pooling = True, ):
    output = tf.layers.conv2d(inputs=inputs, filters = filters, kernel_size=kernel_size,
                              strides=strides, padding="SAME", name = name,
                              )
    if add_max_pooling:
        output = tf.layers.max_pooling2d(inputs=output, strides=strides, padding="SAME",
                                         pool_size=kernel_size, name="{}_max_pool_2d".format(name))
    return tf.nn.leaky_relu(output)

def residual(inputs, out_channels, name = None, is_training = tf.constant(True),
             add_max_pooling = False, pooling_strids=(4, 4)):
    conv2d_output = conv2d(inputs = inputs, filters=out_channels, kernel_size=(3, 3),
                           strides=(1, 1), name ="{}_conv".format(name),
                        add_max_pooling=False)
    identity_output = tf.layers.conv2d(inputs=inputs, filters = out_channels, kernel_size=(1, 1),
                                       strides=(1, 1), padding="SAME", name = "{}_identity".format(name),
                                       )
    output = conv2d_output + identity_output
    if add_max_pooling:
        output = tf.layers.max_pooling2d(inputs=output, strides=pooling_strids, padding="valid",
                                         pool_size=(3, 3), name="{}_max_pool_2d".format(name))

    # pass the is_training flag through instead of hard-coding True, so BN uses moving stats at inference
    output = tf.nn.leaky_relu(tf.layers.batch_normalization(output, training=is_training,
                                                            name="{}_batch_normalizetion".format(name)))

    return tf.nn.dropout(output, keep_prob= 1.0 - tf.cast(is_training, tf.float32) * 0.05)

The RelationalMemory part (the recurrent cell from the Relational recurrent neural networks paper mentioned above, used later as a drop-in RNN cell):

from sonnet.python.modules import basic
from sonnet.python.modules import layer_norm
from sonnet.python.modules import rnn_core
from sonnet.python.modules.nets import mlp

import tensorflow as tf

class RelationalMemory(rnn_core.RNNCore):
    def __init__(self, mem_slots = 10, head_size = 10, num_heads = 3, num_blocks = 1,
                 forget_bias = 1.0, input_bias = 0.0, gate_style = "unit",
                 attension_mlp_layers = 2, key_size = None, name = "relational_memory"):
        super(RelationalMemory, self).__init__(name=name)
        self._mem_slots = mem_slots

        # multi head size
        self._head_size = head_size
        self._num_heads = num_heads

        self._mem_size = self._head_size * self._num_heads

        if num_blocks < 1:
            raise ValueError("num_blocks must be >= 1, Got: {}.".format(num_blocks))
        self._num_blocks = num_blocks
        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ["unit", "memory", None]:
            raise ValueError(
                r"gate_style must be one of ['unit', 'memory', None] Got {}".format(gate_style)
            )
        self._gate_style = gate_style
        if attension_mlp_layers < 1:
            raise ValueError("attension_mlp_layers must be >= 1, Got: {}".format(
                attension_mlp_layers
            ))

        self._attention_mlp_layers = attension_mlp_layers
        # the key size defaults to the per-head size so it stays compatible with the memory's column width
        self._key_size = key_size if key_size else self._head_size
    # init memory matrix
    def initial_state(self, batch_size, trainable = False):
        '''
        # [batch, mem_slots, mem_slots]
        init_state = tf.eye(self._mem_slots, batch_shape=[batch_size])
        if self._mem_size > self._mem_slots:
            difference = self._mem_size - self._mem_slots
            pad = tf.zeros((batch_size, self._mem_slots, difference))
            init_state = tf.concat([init_state, pad], -1)
        elif self._mem_size < self._mem_slots:
            init_state = init_state[:, :, :self._mem_size]
        return init_state
        '''
        init_state = tf.eye(self._mem_slots, self._mem_size, batch_shape=[batch_size])
        return init_state

    def _multihead_attention(self, memory):
        key_size = self._key_size
        value_size = self._head_size

        qkv_size = 2 * key_size + value_size
        total_size = qkv_size * self._num_heads

        qkv = basic.BatchApply(basic.Linear(total_size))(memory)
        qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

        mem_slots = memory.get_shape().as_list()[1]

        qkv_reshape = basic.BatchReshape([mem_slots, self._num_heads,
                                          qkv_size])(qkv)
        qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
        q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)
        q *= qkv_size ** -0.5
        dot_product = tf.matmul(q, k, transpose_b=True)
        weights = tf.nn.softmax(dot_product)

        output = tf.matmul(weights, v)
        output_transpose = tf.transpose(output, [0, 2, 1, 3])

        new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
        return new_memory

    @property
    def state_size(self):
        return tf.TensorShape([self._mem_slots, self._mem_size])

    @property
    def output_size(self):
        return tf.TensorShape(self._mem_slots * self._mem_size)

    def _calculate_gate_size(self):
        if self._gate_style == "unit":
            return self._mem_size
        elif self._gate_style == "memory":
            return 1
        else:
            return 0

    def _create_gates(self, inputs, memory):
        num_gates = 2 * self._calculate_gate_size()
        memory = tf.tanh(memory)

        # shape 2
        inputs = basic.BatchFlatten()(inputs)
        gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
        # shape 3
        gate_inputs = tf.expand_dims(gate_inputs, axis=1)
        gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)

        # broadcast add to every row of memory
        gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
        input_gate, forget_gate = gates

        input_gate = tf.sigmoid(input_gate + self._input_bias)
        forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

        return input_gate, forget_gate

    def _attend_over_memory(self, memory):
        attention_mlp = basic.BatchApply(
            mlp.MLP([self._mem_size] * self._attention_mlp_layers)
        )
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory)
            memory = basic.BatchApply(layer_norm.LayerNorm())(
                memory + attended_memory
            )
            memory = basic.BatchApply(layer_norm.LayerNorm())(
                attention_mlp(memory) + memory
            )
        return memory

    def _build(self, inputs, memory, treat_input_as_matrix = False):
        if treat_input_as_matrix:
            inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
            inputs_reshape =basic.BatchApply(
                basic.Linear(self._mem_size), n_dims=2
            )(inputs)
        else:
            inputs = basic.BatchFlatten()(inputs)
            inputs = basic.Linear(self._mem_size)(inputs)
            inputs_reshape = tf.expand_dims(inputs, 1)

        memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
        next_memory = self._attend_over_memory(memory_plus_input)

        n = inputs_reshape.get_shape().as_list()[1]
        next_memory = next_memory[:,:-n,:]

        if self._gate_style == "unit" or self._gate_style == "memory":
            self._input_gate, self._forget_gate = self._create_gates(
                inputs_reshape, memory
            )
            next_memory = self._input_gate * tf.tanh(next_memory)
            next_memory += self._forget_gate * memory

        output = basic.BatchFlatten()(next_memory)
        return output, next_memory

    @property
    def input_gate(self):
        self._ensure_is_connected()
        return self._input_gate

    @property
    def forget_gate(self):
        self._ensure_is_connected()
        return self._forget_gate


if __name__ == "__main__":
    pass
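Before wiring the cell into the model, a small smoke test (my own, assuming TF 1.x and dm-sonnet are installed) that runs RelationalMemory through tf.nn.dynamic_rnn the same way the decoder below does:

import numpy as np
import tensorflow as tf

cell = RelationalMemory(mem_slots=4, head_size=16, num_heads=2, num_blocks=1,
                        name="rm_smoke_test")
inputs = tf.placeholder(tf.float32, [None, 10, 32])           # [batch, time, feature]
outputs, final_memory = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs, feed_dict={inputs: np.random.rand(2, 10, 32)})
    print(out.shape)  # (2, 10, mem_slots * head_size * num_heads) == (2, 10, 128)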

Main model code (because of GPU memory limits, only 40 decoding steps are used here):

import tensorflow as tf
from model.model_utils import residual
from data_preprcess.coco_data_loader_v4 import batch_generator
import os
import random

import numpy as np
from PIL import Image, ImageDraw
import cv2
from uuid import uuid1

from senior_model.reasonNN import RelationalMemory

class PolygonRNN(object):
    def __init__(self, input_height = 224, input_width = 224,
                 channel_num = 3, batch_size = 4, D = 28,
                 max_step_num = 30, dense_dim = 100):
        self.input_img = tf.placeholder(tf.float32, [None, input_height, input_width, channel_num],
                                        name="input_img")
        self.is_training = tf.placeholder(tf.bool, [], name="is_training")
        self.keep_prob = tf.placeholder(tf.float32, [], name="keep_prob")

        # max step mask used in batch_softmax_output
        self.step_mask = tf.placeholder(tf.int32, [None], name="step_mask")

        self.max_step_num = max_step_num
        self.dense_dim = dense_dim

        # the data pipeline must feed vertices in boundary order for the RNN targets to be consistent;
        # values lie in [0, self.D * self.D + 1], where self.D * self.D + 1 is the stop token
        self.rnn_step_target = tf.placeholder(tf.int32, [None, max_step_num], name="rnn_step_target")

        self.batch_size = batch_size
        self.D = D

        batch_skip_features_for_decode = self.encoder_layer()

        self.DxD_plus_1_W = tf.Variable(tf.random_normal(
            shape=[28 * 28 * 16, D * D + 1 + 1], mean=0.0, stddev=1.0
        ), name="DxD_plus_1_W")
        self.DxD_plus_1_b = tf.Variable(tf.constant([1.0] *(D * D + 1 + 1)),
                                        name="DxD_plus_1_b")

        logits, targets = self.decoder(batch_skip_features_for_decode)
        self.rnn_opt_construct(logits = logits, targets = targets)

    def conv_res_upsampling(self, input, parent_name = "conv1", filters = 256):
        with tf.variable_scope("{}_upsampling".format(parent_name)):
            conv2d_output = tf.layers.conv2d(inputs=input, filters = filters,
                                             kernel_size = (3, 3),
                                             strides=(1, 1),
                                             padding='same', name="{}_conv2_3x3".format(parent_name))
            bilinear_output = tf.image.resize_bilinear(conv2d_output, size=(112, 112),
                                                       name="{}_bilinear".format(parent_name))
            return bilinear_output

    def encoder_layer(self):
        conv1_output = tf.layers.conv2d(inputs=self.input_img, filters = 64,
                                        kernel_size = (3, 3),
                                        strides=(1, 1),
                                        padding='same', name="conv1_output")
        conv1_output = tf.nn.relu(tf.layers.batch_normalization(inputs=conv1_output, training=self.is_training,
                                                                name="conv1_normal_layer"))
        conv1_output = tf.layers.max_pooling2d(
            inputs=conv1_output, pool_size = (2, 2), strides = (2, 2),
        )
        res1_output = residual(inputs=conv1_output, out_channels = 256, name = "res1_output",
                               is_training = self.is_training,
                               add_max_pooling = True, pooling_strids=(4, 4))

        res2_output = residual(inputs=res1_output, out_channels = 512, name = "res2_output",
                               is_training = self.is_training,
                               add_max_pooling = False)

        res3_output = residual(inputs=res2_output, out_channels = 1024, name = "res3_output",
                               is_training = self.is_training,
                               add_max_pooling = False)

        res4_output = residual(inputs=res3_output, out_channels = 2048, name = "res4_output",
                               is_training = self.is_training,
                               add_max_pooling = False)

        conv1_part = self.conv_res_upsampling(conv1_output, "conv1")
        res1_part = self.conv_res_upsampling(res1_output, "res1")
        res2_part = self.conv_res_upsampling(res2_output, "res2")
        res4_part = self.conv_res_upsampling(res4_output, "res4")

        conv_res_feature_before_concat = [conv1_part, res1_part, res2_part, res4_part]

        conv_res_feature_fused = tf.concat(conv_res_feature_before_concat, axis=-1)
        down_conv2d_1 = tf.layers.conv2d(inputs=conv_res_feature_fused, filters=128, strides=(1, 1), kernel_size=(3, 3),
                                         name="down_conv2d_1", padding="same")
        down_maxpool2d_1 = tf.layers.max_pooling2d(inputs=down_conv2d_1, pool_size = (3, 3), strides = (2, 2),
                                                   name="down_maxpool2d_1", padding="same")
        down_conv2d_2 = tf.layers.conv2d(inputs=down_maxpool2d_1, filters=128, strides=(1, 1), kernel_size=(3, 3),
                                         name="down_conv2d_2", padding="same")
        down_maxpool2d_2 = tf.layers.max_pooling2d(inputs=down_conv2d_2, pool_size = (3, 3), strides = (2, 2),
                                                   name="down_maxpool2d_2", padding="same")
        skip_features = tf.layers.conv2d(inputs=down_maxpool2d_2, filters=128, strides=(1, 1), kernel_size=(3, 3),
                                         name="skip_features", padding="same")

        # [batch_num, 28, 28, 128]
        return skip_features

    def decoder(self, batch_skip_features):
        # batch_skip_features [batch_num, 28, 28, 128]
        batch_skip_features = batch_skip_features[:self.batch_size, ...]

        batch_skip_features = residual(batch_skip_features, out_channels=10, add_max_pooling=False, name="red_reduce")
        dense_output = tf.reshape(tf.layers.dense(tf.reshape(batch_skip_features, [self.batch_size, 28* 28 *10]), units=(self.dense_dim) * self.max_step_num, name="prod_layer"),
                                          [self.batch_size, self.dense_dim, self.max_step_num])

        # [batch, max_step_num, self.dense_dim]
        dense_output_T = tf.transpose(dense_output, [0, 2, 1])

        relationalMemoryCell = RelationalMemory(
                mem_slots = self.max_step_num, head_size = 100, num_heads = 4, num_blocks = 4,
                name="relationalMemoryCell"
            )

        outputs, _ = tf.nn.dynamic_rnn(cell=relationalMemoryCell,
                                           inputs=dense_output_T,
                                           sequence_length=None,
                                           dtype=tf.float32)
        self.batch_softmax_output_T = tf.layers.dense(inputs=tf.concat([outputs, dense_output_T], axis=-1), units=28 * 28 + 1 + 1, name="memory_to_conclusion")
        self.batch_softmax_output = tf.transpose(self.batch_softmax_output_T, [0, 2, 1])

        seq_mask = tf.sequence_mask(lengths=self.step_mask, maxlen= self.max_step_num)

        logits = tf.boolean_mask(self.batch_softmax_output_T, seq_mask)
        targets = tf.one_hot(tf.boolean_mask(self.rnn_step_target, seq_mask), depth=self.D * self.D + 1 + 1)

        return (logits, targets)

    def rnn_opt_construct(self, logits, targets):
        self.rnn_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            labels=targets,
            logits=logits,
        ))
        self.rnn_opt = tf.train.AdamOptimizer(0.001).minimize(self.rnn_loss)

        # the prediction-decoding procedure uses the following variables.
        # [batch, max_step_num]
        self.prediction = tf.cast(tf.argmax(self.batch_softmax_output, axis=1), tf.int32)
        # [batch] max steps
        # used to slice self.prediction
        self.step_num_prediction = tf.argmax(tf.cast(tf.equal(self.prediction ,
                                                              tf.fill(dims = [self.batch_size, self.max_step_num], value= self.D * self.D + 1)),
                                                     tf.int32), axis=-1)

    @staticmethod
    def visualize_prediction(batch_img_array ,prediction, step_num_prediction, draw_text = True, transform_to_bgr2lab = False):
        img_array = batch_img_array[0]
        prediction_f = prediction[0]
        step_num_prediction_f = step_num_prediction[0]
        if step_num_prediction_f == 0:
            step_num_prediction_f = len(prediction_f)
        else:
            #print("step_num_prediction_f : {}".format(step_num_prediction_f))
            pass

        rc_t2_list = list(map(lambda index: list(map(lambda r_or_c: r_or_c / (28) * 224, divmod(index, 28))), prediction_f[:step_num_prediction_f].tolist()))

        if transform_to_bgr2lab:
            img_array = cv2.cvtColor(img_array.astype(np.uint8), cv2.COLOR_BGR2LAB)
        img = Image.fromarray(img_array.astype(np.uint8))

        img_vertex_list = np.asarray(list(map(lambda rc: [rc[1], rc[0]], rc_t2_list)), dtype=np.int32)

        draw = ImageDraw.Draw(img)
        img_vertex_list = cv2.convexHull(points=img_vertex_list, returnPoints=True)
        img_vertex_list = img_vertex_list.reshape([-1, 2])

        for idx, (r, c) in enumerate(rc_t2_list):
            xy = np.array([c, r])
            if draw_text:
                draw.text(xy = xy.tolist(), text=str(idx))
            else:
                draw.ellipse((xy - 2).tolist() + (xy + 2).tolist())

        # draw Convex Hull
        draw.line(xy = list(map(tuple ,img_vertex_list.tolist() + img_vertex_list.tolist()[:1])))
        img.save(r"E:\Temp\valid\{}.jpg".format(uuid1()))

    @staticmethod
    def filter_sample_Sigmoid(Sigmoid_val, topN = 10):
        batch_size = Sigmoid_val.shape[0]
        flatten_sig_list = Sigmoid_val.reshape([batch_size, -1]).tolist()

        def map_choice_val(sig_list, choice_val):
            return [0 if ele > choice_val else ele for ele in sig_list]

        flatten_sig_array = np.array(list(map(lambda sig_list: map_choice_val(sig_list ,random.choice(sorted(sig_list)[::-1][:topN])), flatten_sig_list)))
        return flatten_sig_array.reshape([batch_size, 28, 28])

    @staticmethod
    def train_rnn():
        batch_size = 4
        max_step_num = 40
        polygonRNN_ext = PolygonRNN(batch_size=batch_size, max_step_num = max_step_num)
        saver = tf.train.Saver()

        train_data_gen = batch_generator(dataType="train2017", batch_size=batch_size)
        valid_data_gen = batch_generator(dataType="val2017", batch_size=batch_size)

        step = 0
        epoch = 3

        with tf.Session() as sess:
            if os.path.exists(r"C:\Coding\Python\PolygonRNN_v1\poly_uu{}.meta".format(epoch)):
                saver.restore(sess, save_path=r"C:\Coding\Python\PolygonRNN_v1\poly_uu{}".format(epoch))
                print("load exists")
            else:
                sess.run(tf.global_variables_initializer())
                print("init new")

            while True:
                train_data = train_data_gen.__next__()
                if train_data is None:
                    print("one epoch end")
                    train_data_gen = batch_generator(dataType="train2017", batch_size=batch_size)
                    epoch += 1
                    step = 0
                    train_data = train_data_gen.__next__()

                batch_img_array, batch_vertex_mask, batch_step_mask, batch_rnn_step_target = train_data
                _, train_loss, prediction, step_num_prediction, batch_softmax_output_T= sess.run([polygonRNN_ext.rnn_opt, polygonRNN_ext.rnn_loss,
                                                                                                  polygonRNN_ext.prediction, polygonRNN_ext.step_num_prediction,
                                                                                                  polygonRNN_ext.batch_softmax_output_T,
                                                                                                  ],
                                                                                                 feed_dict={
                                                                                                     polygonRNN_ext.input_img: batch_img_array,
                                                                                                     polygonRNN_ext.step_mask: batch_step_mask,
                                                                                                     polygonRNN_ext.rnn_step_target: batch_rnn_step_target[:, :max_step_num],

                                                                                                     polygonRNN_ext.is_training: True,
                                                                                                     polygonRNN_ext.keep_prob: 0.9
                                                                                                 })
                step += 1
                if step % 50 == 0:
                    # visualize train predictions in RGB space
                    PolygonRNN.visualize_prediction(batch_img_array, prediction, step_num_prediction)

                    val_data = valid_data_gen.__next__()
                    if val_data is None:
                        print("val consume end")
                        valid_data_gen = batch_generator(dataType="val2017", batch_size=batch_size)
                        val_data = valid_data_gen.__next__()

                    batch_img_array, batch_vertex_mask, batch_step_mask, batch_rnn_step_target = val_data
                    val_loss, prediction, step_num_prediction = sess.run([polygonRNN_ext.rnn_loss, polygonRNN_ext.prediction, polygonRNN_ext.step_num_prediction,
                                                                          ],
                                                                         feed_dict={
                                                                             polygonRNN_ext.input_img: batch_img_array,
                                                                             polygonRNN_ext.step_mask: batch_step_mask,
                                                                             polygonRNN_ext.rnn_step_target: batch_rnn_step_target[:, :max_step_num],

                                                                             polygonRNN_ext.is_training: True,
                                                                             polygonRNN_ext.keep_prob: 1.0
                                                                         })
                    print("epoch : {} step : {} train_loss: {} val_loss: {}".format(epoch, step, train_loss, val_loss))
                    # visualize val predictions in BGR2LAB space
                    PolygonRNN.visualize_prediction(batch_img_array, prediction, step_num_prediction, transform_to_bgr2lab=True)

                if step % 500 == 0:
                    saver.save(sess, save_path=r"C:\Coding\Python\PolygonRNN_v1\poly_uu{}".format(epoch))
                    print("model have dumped")

if __name__ == "__main__":
    PolygonRNN.train_rnn()

Finally, some examples of the estimated "convex hull" on the validation set (experimenting on a larger dataset with more consistent "backgrounds" would probably give better results):