1. 程式人生 > >MACE原始碼解析【ARM卷積篇(二)】1*1卷積實現

MACE原始碼解析【ARM卷積篇(二)】1*1卷積實現

前言

本文來解析一下MACE中ARM程式碼的1*1卷積的實現。1*1卷積在CNN中是比較特殊的一種操作,不再強調領域操作,一般用到1*1卷積有以下幾種情況(相互之間不獨立)
1.單純的加強非線性對映,不強調領域CNN的特徵提取功能
2.bottleneck結構中進行特徵圖數量的改變
3.depthWise 卷積中組成部分

除了以上三點外其他情況歡迎補充

本文涉及的原始碼檔案:

mace/mace/kernels/arm/conv_2d_neon_1x1.cc
mace/mace/kernels/gemm.cc

從卷積到矩陣乘法

// mace/mace/kernels/arm/conv_2d_neon_1x1.cc
#include "mace/kernels/arm/conv_2d_neon.h" #include "mace/kernels/gemm.h" namespace mace { namespace kernels { void Conv2dNeonK1x1S1(const float *input, const float *filter, const index_t batch, const index_t height, const
index_t width, const index_t in_channels, const index_t out_channels, float *output) { for (index_t b = 0; b < batch; ++b) { Gemm(filter, input + b * in_channels * height * width, 1, out_channels, in_channels, height * width, output + b * out_channels * height * width); } } } // namespace kernels
} // namespace mace

MACE中1*1卷積的程式碼如上,可以看到其實就是在每一個batch中呼叫了gemm矩陣乘法運算。這節簡單說明卷積操作是如何變成矩陣乘法的。假設輸入通道數為C1,輸出通道數為C2。則一般卷積核引數為C1xC2xkhxkw,因此卷積核大小為1*1時,卷積核就從四維變成了兩維矩陣K,大小為C1*C2。在單batch下,假設上一次輸入資料大小為 C1*H*W,把它reshape成一個C1*(H*W)的矩陣F,這樣多通道分別卷積再求和的過程就可以用這兩個矩陣乘積來表示:

Z=KtF
得到了大小為C2*(H*W)的矩陣Z。其實就是單通道的卷積運算退化成了一個矩陣和一個標量的點乘運算了。下圖舉了一個C1=2,C2=3,輸入和輸出特徵圖大小為2*3(1*6、3*2也一樣)的例子。
這裡寫圖片描述
矩陣乘法做完後,就完成了單batch的1*1卷積運算。I0、I1f分別表示2通道的輸入資料,在這裡一個通道w*h個數據被拉成了一行。原始碼中沒有reshape函式?因為記憶體排布並沒有變,所以其實不需要額外的操作。

gemm的實現

那麼轉而來看gemm的實現。

/**
 * Gemm does fast matrix multiplications with batch.
 * It is optimized for arm64-v8 and armeabi-v7a using neon.
 *
 * We adopt two-level tiling to make better use of l1 cache and register.
 * For register tiling, function like GemmXYZ computes gemm for
 * matrix[X, Y] * matrix[Y, Z] with all data being able to fit in register.
 * For cache tiling, we try to compute one block of multiplication with
 * two input matrices and one output matrix fit in l1 cache.
 */

原始碼中開始的註釋如是說。為了更好的優化,MACE應用了矩陣分塊乘法,所以看這部分程式碼前建議先停下來複習一下矩陣分塊乘法的公式。
MACE把大矩陣運算分為兩級的矩陣分塊乘法。第一級的實現名字都是GemmXYZ這種形式,表示大小為[X,Y]和[Y,Z]的矩陣相乘,主要的NEON優化也是在這些函式中。這一級的矩陣計算大小都很小,最大也就Gemm688,所以大部分情況下變數都可以保持在暫存器上,避免暫存器變數溢位到棧上帶來的時間開銷。這一級的分塊矩陣乘法運算稱為register tiling
第二級優化則是把若干register tiling組成一個block,保證一個block內的記憶體需求(2個矩陣輸入+1個矩陣輸出)不會超出L1 cache的大小,提高cache命中率。稱為cache tiling。MACE為了記憶體搬運優化做了兩級的分塊矩陣乘法。

register tiling

#define MACE_GEMM_PART_CAL_8(RC, RA, RAN)                      \
  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0);   \
  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1);   \
  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0);  \
  c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1);  \
  c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0);  \
  c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1);  \
  c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \
  c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1);

#define MACE_GEMM_PART_CAL_4(RC)                              \
  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0);  \
  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1);  \
  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \
  c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);

子矩陣運算關鍵就是這兩個巨集,分別為8(4)個浮點向量和8(4)個標量的累乘和,,也就是我們矩陣運算中的基本操作。MACE_GEMM_PART_CAL_4(RC) 的一次呼叫實現的是1*4(A)和4*4(B)矩陣的乘法。

inline void Gemm144(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  MACE_UNUSED(stride_a);
  MACE_UNUSED(stride_c);
  float32x4_t a0;
  float32x4_t b0, b1, b2, b3;
  float32x4_t c0;

  a0 = vld1q_f32(a_ptr);

  b0 = vld1q_f32(b_ptr);
  b1 = vld1q_f32(b_ptr + 1 * stride_b);
  b2 = vld1q_f32(b_ptr + 2 * stride_b);
  b3 = vld1q_f32(b_ptr + 3 * stride_b);

  c0 = vld1q_f32(c_ptr);

  MACE_GEMM_PART_CAL_4(0);

  vst1q_f32(c_ptr, c0);
#else
  GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}

Gemm144為例,輸入矩陣A,B分別可以裝載到1個和4個1*4的浮點向量中去。再通過乘累加指令把計算結果存入1*4的結果向量中。而類似Gemm884這樣的函式,只不過是A矩陣每行多取一個向量。
所以在使用MACE_GEMM_PART_CAL_8計算時需要多2個引數,這兩個引數組成A矩陣的一行。呼叫程式碼長成這樣:

  MACE_GEMM_PART_CAL_8(0, 0, 1);
  MACE_GEMM_PART_CAL_8(1, 2, 3);
  MACE_GEMM_PART_CAL_8(2, 4, 5);
  MACE_GEMM_PART_CAL_8(3, 6, 7);
  MACE_GEMM_PART_CAL_8(4, 8, 9);
  MACE_GEMM_PART_CAL_8(5, 10, 11);
  MACE_GEMM_PART_CAL_8(6, 12, 13);
  MACE_GEMM_PART_CAL_8(7, 14, 15);

第一級的矩陣乘法就是這一系列GemmXYZ組成,而他們的呼叫則組成了第二級,繼續向下看。

cache tiling

這一部分的主體在GemmTileGemm這兩個函式上。畢竟是工程程式碼,需要對邊界進行處理,對不同編譯和裝置環境進行優化。所以程式碼顯得比較龐雜。為了理清邏輯我把aarch64clang 巨集控制的部分程式碼刪除、並暫時把邊界處理的程式碼也給刪掉,現在程式碼看上去是這樣的:

GemmTile(const float *A,
                     const float *B,
                     const index_t height,
                     const index_t K,
                     const index_t width,
                     const index_t stride_a,
                     const index_t stride_b,
                     const index_t stride_c,
                     float *C) {
  index_t h = 0;
  index_t w = 0;
  index_t k = 0;
  int reg_height_tile = 8;
  int reg_K_tile = 8;

  for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
    for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
      const float *a_ptr = A + (h * stride_a + k);
      for (w = 0; w + 3 < width; w += 4) {
        const float *b_ptr = B + (k * stride_b + w);
        float *c_ptr = C + (h * stride_c + w);
        Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      }
    }
  }
}

第一級的矩陣運算放在Gemm884中,此時可以把Gemm884看做單個元素看待。這樣這裡的三層迴圈就和普通的矩陣乘法一致了(回憶下分塊矩陣乘法的公式,其實就是一個遞迴的過程)。
我們再把邊界處理的程式碼加上去

inline void GemmTile(const float *A,
                     const float *B,
                     const index_t height,
                     const index_t K,
                     const index_t width,
                     const index_t stride_a,
                     const index_t stride_b,
                     const index_t stride_c,
                     float *C) {
  index_t h = 0;
  index_t w = 0;
  index_t k = 0;
  int reg_height_tile = 6;
  int reg_K_tile = 4;

  for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
    for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
      const float *a_ptr = A + (h * stride_a + k);
      for (w = 0; w + 3 < width; w += 4) {
        const float *b_ptr = B + (k * stride_b + w);
        float *c_ptr = C + (h * stride_c + w);
        Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      }
      if (w < width) {
          const float *b_ptr = B + (k * stride_b + w);
          float *c_ptr = C + (h * stride_c + w);
          GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w,
              stride_a, stride_b, stride_c, c_ptr);
      }
    }
    if (k < K) {
        const float *a_ptr = A + (h * stride_a + k);
        const float *b_ptr = B + k * stride_b;
        float *c_ptr = C + h * stride_c;
        GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b,
            stride_c, c_ptr);
    }
  }
  if (h < height) {
      index_t remain_h = height - h;
      for (k = 0; k < K - reg_K_tile; k += reg_K_tile) {
          const float *a_ptr = A + (h * stride_a + k);
          index_t w;
          for (w = 0; w + 3 < width; w += 4) {
              const float *b_ptr = B + (k * stride_b + w);
              float *c_ptr = C + (h * stride_c + w);
              GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
          }
          if (w < width) {
              const float *b_ptr = B + (k * stride_b + w);
              float *c_ptr = C + (h * stride_c + w);
              GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a,
                  stride_b, stride_c, c_ptr);
          }
      }
      if (k < K) {
          const float *a_ptr = A + (h * stride_a + k);
          const float *b_ptr = B + k * stride_b;
          float *c_ptr = C + h * stride_c;
          GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b,
              stride_c, c_ptr);
      }
  }
}

對比一下可以看到一個block把3個維度上不足步長的部分用GemmBlock計算了。aarch64clang巨集包中的程式碼,內嵌了NEON的彙編程式碼,可以更好的安排指令排布以及暫存器的使用,可參考GemmXYZ解讀,不贅述了。

Gemm

我們至下而上的終於講到了矩陣乘法最上層介面函式。和GemmTile函式一樣先去掉細枝末節:

// A: height x K, B: K x width, C: height x width
void Gemm(const float *A,
    const float *B,
    const index_t batch,
    const index_t height,
    const index_t K,
    const index_t width,
    float *C,
    const bool transpose_a,
    const bool transpose_b) {
    memset(C, 0, sizeof(float)* batch * height * width);

    // It is better to use large block size if it fits for fast cache.
    // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
    // the block size should be sqrt(32k / sizeof(T) / 3).
    // As number of input channels of convolution is normally power of 2, and
    // we have not optimized tiling remains, we use the following magic number
    const index_t block_size = 64;
    const index_t block_tile_height = RoundUpDiv(height, block_size);
    const index_t block_tile_width = RoundUpDiv(width, block_size);
    const index_t block_tile_k = RoundUpDiv(K, block_size);
    const index_t block_tile[3] = { block_tile_height, block_tile_width,
        block_tile_k };
    const index_t remain_height = height % block_size;
    const index_t remain_width = width % block_size;
    const index_t remain_k = K % block_size;
    const index_t remain[3] = { remain_height, remain_width, remain_k };

#pragma omp parallel for collapse(3)
    for (index_t n = 0; n < batch; ++n) {
        for (index_t bh = 0; bh < block_tile[0]; ++bh) {
            for (index_t bw = 0; bw < block_tile[1]; ++bw) {
                const float *a_base = A + n * height * K;
                const float *b_base = B + n * K * width;
                float *c_base = C + n * height * width;

                const index_t ih_begin = bh * block_size;
                const index_t ih_end =
                    bh * block_size +
                    (bh == block_tile[0] - 1 && remain[0] > 0 ? remain[0] : block_size);
                const index_t iw_begin = bw * block_size;
                const index_t iw_end =
                    bw * block_size +
                    (bw == block_tile[1] - 1 && remain[1] > 0 ? remain[1] : block_size);

                for (index_t bk = 0; bk < block_tile[2]; ++bk) {
                    const index_t ik_begin = bk * block_size;
                    const index_t ik_end =
                        bk * block_size + (bk == block_tile[2] - 1 && remain[2] > 0
                        ? remain[2]
                        : block_size);

                    Tensor trans_a;
                    Tensor trans_b;
                    const float *real_a = nullptr;
                    const float *real_b = nullptr;
                    float *real_c = c_base + (ih_begin * width + iw_begin);
                    index_t stride_a;
                    index_t stride_b;
                    index_t stride_c = width;

                    real_a = a_base + (ih_begin * K + ik_begin);
                    stride_a = K;

                    real_b = b_base + (ik_begin * width + iw_begin);
                    stride_b = width;

                    // inside block:
                    // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
                    GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin,
                        iw_end - iw_begin, stride_a, stride_b, stride_c, real_c);
                }  // bk
            }    // bw
        }      // bh
    }        // n
}

主體依然是矩陣乘法的三層迴圈,只是這次基礎元素從一個register tiel計算變成了一個整個block計算,正如上面說的。這麼做是為了該block涉及的記憶體可以存在L1 cache中,減少計算時的cache miss。預設的block大小為64,此外Gemm把尾部不足64的部分丟給GemmTile去處理了。在迴圈的尾部傳入的block大小是可能不足64的。

總結

  1. 本文介紹了MACE的1*1卷積實現,實際上是呼叫矩陣乘法來完成單個batch內的卷積操作。在其gemm演算法中,使用了兩級矩陣分塊乘法的方案。儘量避免暫存器變數溢位到棧上和cache miss這兩種情況。原始矩陣運算為了計算一個結果對輸入的訪存跨度是很大的(取整行和整列),cache miss和暫存器溢位是必然比較頻繁。
  2. 可以看到實現上不足步長部分,一是會導致邏輯分支,二是沒有NEON優化,所以網路設計的時候不管長寬還是通道數都儘量取4、64的整數倍,會得到更好的計算效能。

相關推薦

MACE原始碼解析ARM()1*1實現

前言 本文來解析一下MACE中ARM程式碼的1*1卷積的實現。1*1卷積在CNN中是比較特殊的一種操作,不再強調領域操作,一般用到1*1卷積有以下幾種情況(相互之間不獨立) 1.單純的加強非線性對映,不強調領域CNN的特徵提取功能 2.bottleneck

Mace原始碼解析 1×NN×11*1

1*7 卷積原始碼解讀 #if defined(MACE_ENABLE_NEON) #include <arm_neon.h> #endif #include "mace/kernels/arm/conv_2d_neon.h" namespac

java之ArrayList初始容量原始碼解析jdk 1.8

ArrayList解析 繼承的類和實現的介面 public class ArrayList<E>extends AbstractList<E>implements List<

redisson分散式鎖redLock原始碼解析未完

一、準備階段 1、原理 一個客戶端需要做如下操作來獲取鎖: 1.獲取當前時間(單位是毫秒) 2.輪流用相同的key和隨機值在N個節點上請求鎖,在這一步裡,客戶端在每個master上請求鎖時會有一個和總的鎖釋放時間相比小的多的超時

別翻了,這文章絕對讓你深刻理解java類的載入以及ClassLoader原始碼分析JVM

目錄 1、什麼是類的載入(類初始化) 2、類的生命週期 3、介面的載入過程 4、解開開篇的面試題 5、理解首次主動使用 6、類載入器 7、關於名稱空間

Unity3D技術文檔翻譯1.1 AssetBundle 工作流

如何 倉庫 ring 資源 string int 觀察 你是 本地 譯者前言:本章是關於從創建到加載,再到使用 AssetBundle 的整個流程的概述。閱讀本章將對 AssetBundle 的工作流程有個簡單而全面的了解。 本章原文所在章節:【Unity Manual】

Python Web框架Django框架第一基礎

界面 博客 make ted 分割 增加 welcom 關系 可選 Django框架第一篇基礎【DjangoMTV模式】 老師博客【www.cnblogs.com/yuanchenqi/articles/6811632.html】 同學博客:http://www.

Unity3D技術文件翻譯1.6 使用 AssetBundle Manager

上一章:【Unity3D技術文件翻譯】第1.5篇 使用 AssetBundles 本章原文所在章節:【Unity Manual】→【Working in Unity】→【Advanced Development】→【AssetBundles】→【AssetBundle Manager】 As

搞定Java併發程式設計1:執行緒的五種可用狀態

本文轉載自牛客網上一網友的回答:概括的解釋下執行緒的幾種可用狀態 第一種狀態:新建(new):新建了一個執行緒物件。例如,Thread thread = new Thread(); 第二種狀態:可執行狀態(Runnable):又叫“就緒狀態”。執行緒新建後,其他執行緒(比如main執行

OpenCV入門教程之 一覽眾山小:OpenCV 2.4.8 or OpenCV 2.4.9元件結構全解析

毛星雲,網路ID「淺墨」,90後,熱愛遊戲開發、遊戲引擎、計算機圖形、實時渲染等技術,就職於騰訊互娛。 微軟最有價值專家 著作《Windows遊戲程式設計之從零開始》、《OpenCV3程式設計入門》 碩士就讀於南京航空航天大學航天學院(2013級碩士研究生),已於2016年三月畢業。本科

JavaSE系列—基礎7——註解基礎知識

目錄 目錄 註解概念 註解,元資料的一種形式,提供了和程式有關但不是程式本身的一部分的資料。添加了註解對程式碼沒有直接的影響。 註解有很多用途,其中包含: 編譯器的資訊——註解可以用來使編譯器檢測錯誤或者忽略警告。 編譯時和部署時處

專案原始碼- 模仿知乎日報吐血高仿知乎日報

對之前的模仿做品進行了改善改善。。。再改善。。。(僅供學習) 多說無益。。。。上圖才是王道: 這個東西越模仿發現他的東西就越多,離上次的模仿時間已經過去好久了,這一版本的介面看似好很多,但還是

JavaSE系列-基礎6——泛型方法

泛型方法是引入自己型別引數的方法。和宣告一個泛型型別是相似的,但是這個型別引數的範圍是在宣告的方法體內。靜態的和非靜態的泛型方法都是允許的,以及泛型類建構函式。 泛型方法的語法包括一個在菱形括號內的一個型別引數,並出現在方法返回型別之前。對於靜態方法來說,型別

JavaSE系列-基礎6——有界型別引數

目錄 有界型別引數 可能有時候要限制在引數化型別中可以用作型別引數的型別。舉個例子來說,一個對數字進行操作的方法可能只希望接受Number或其子類的例項。這是有界型別引數。 宣告一個有界型別引數,列出型別引數的名稱,並且跟隨extends關鍵字,

Vue 原始碼解析 - 例項化 Vue 前(

前言 上一篇文章,大概的講解了Vue例項化前的一些配置,如果沒有看到上一篇,通道在這裡:Vue 原始碼解析 - 例項化 Vue 前(一) 在上一篇的結尾,我說這一篇後著重講一下 defineReactive 這個方法,這個方法,其實就是大家可以在外面看見一些文章對 vue 實現資料雙向繫結原理的過程。

springmvc 的請求流程:(springmvc 的三大元件之一)處理器對映器的配置和通過處理器對映器返回請求鏈的原始碼流程

總結 策略模式,每一種對映器方案都提供了對url 的解析的方案都是不同的 DispatcherServlet 拿著執行鏈去尋找對應的處理器介面卡(HandlerAdapter)為什麼要引入介面卡?因為處理器(Handler)有很多種,DispatcherServlet

React Native 安卓開發----側邊欄的實現DrawerLayoutAndroid以及第三方框架react-native-side-menu的使用第六

前言 做過安卓原生開發的童鞋們應該都做過側邊欄這個東西,而且對於開源框架SlidingMenu和android官方側滑選單DrawerLayout應該都不陌生。 那麼今天也在這裡給大家介紹一下React-Native中的側滑選單DrawerLayoutAnd

影象縮放之一近鄰取樣插值和其速度優化

    unsigned long dst_width=Dst.width;     TARGB32* pDstLine=Dst.pdata;     unsigned long srcy_16=0;     unsigned long for4count=dst_width/4*4;     for (un

為什麼MySQL要用B+樹?聊聊B+樹與硬碟的前世今生宇哥帶你玩轉MySQL 索引()

為什麼MySQL要用B+樹?聊聊B+樹與硬碟的前世今生   在上一節,我們聊到資料庫為了讓我們的查詢加速,通過索引方式對資料進行冗餘並排序,這樣我們在使用時就可以在排好序的資料裡進行快速的二分查詢,使得查詢效率指數提升。但是我在結尾同樣提到一個問題,就是記憶體大小一般是很有限的,不可能把一個表所有的

Head First Servlets and JSP筆記1

http header 多線程 轉換 throw 接收 找到 write ide 1、把Java放到HTML中,JSP應運而生。 2、Servlet本身並沒有main()方法,所以必須要有其他Java程序去調用它,這個Java程序就是Web容器(Container)