1. 程式人生 > >詳解如何將TensorFlow訓練的模型移植到Android手機

詳解如何將TensorFlow訓練的模型移植到Android手機

前言

本文中出現的TF皆為TensorFlow的簡稱。

先說兩句題外話吧,TensorFlow 前兩天熱熱鬧鬧的釋出了正式版r1.0,可感覺自己才剛剛上手 r0.12,這個時代發展的太快,腳步是一刻也不能停啊~

但是不得不吐槽 TensorFlow的向下相容做的實在不太友好,每次更新完版本,以前的程式碼就跑不動,各種提示您使用的函式已經不存在。。。

程式碼積攢的越來越多,全部針對新版本翻改一遍,工程真是浩大。但是喜新厭舊,手賤如我,每次都忍不住點了更新。不過這次忍的還算不錯,到目前還沒更新,繼續忍住

在之前的文章中,我介紹瞭如何實現 TensorFlow官網的Mobile教程:
【將Tensorflow移植到安卓手機,實現物體識別、行人檢測和影象風格遷移】


但在那個教程中,TensorFlow提供了完整的、已經構建好的Android專案,我們需要做的總結下來只有3步:1、搭建環境;2、編譯;3、安裝到手機

這當然還不夠,我們的最終目的當然是要為我所用,所以怎樣才能移植自己訓練好的TF模型到安卓手機呢?換句話說,怎樣將訓練好的模型放入Android專案中並進行成功編譯?又或者怎樣建立自己的Android Tensorflow專案?

PS:
之前沒有安卓開發的經驗,純粹是為了實現將TF模型移植到手機才開始上手,目前屬於入門級小白,如有錯誤之處,歡迎批評指正!

手機呼叫TF模型的過程簡介:

1、 儲存訓練完畢的TF模型
2、 在Android專案中匯入TF模型、匯入Android平臺呼叫TF模型需要的jar包和so檔案 (它們負責TF模型的解析和運算)
3、定義變數、儲存資料,通過jar包提供的介面進行模型的呼叫

環境

TensorFlow版本: r0.12
python 版本:2.7
Python IDE: Spyder
Android IDE : Android Studio

移植過程

我們以mnist資料集上自己訓練的一個影象識別模型為例,進行講解

一、 在使用python程式碼編寫的TF模型定義中為模型的輸入層和輸出層Tensor Variable分別指定名字(通過形參 ‘name’)

X = tf.placeholder(tf.float32, shape = […], name=‘input’)  //網路的輸入
Y = tf.nn.softmax(tf.matmul
(f, out_weights) + out_biases, name=’output’) //網路的輸出

名字可以隨便起,以方便好記為主,後面還會反覆用到。我起的是input和output。

二、 將使用TensorFlow訓練好的模型儲存為.pb檔案

在模型訓練結束後的程式碼位置,新增下述兩句程式碼,可將模型儲存為.pb檔案

output_graph_def = tf.graph_until.convert_variables_to_constants(session, session.graph_def, output_node_names=[‘output’])
//形參output_node_names用於指定輸出的節點名稱

貼一個說明文件,幫助大家進一步瞭解這個函式

這裡寫圖片描述

with tf.gfile.FastGFile(model\mnist.pb, mode = ’wb’) as f:
    f.write(output_graph_def.SerializeToString())

第一個引數用於指定輸出的檔案存放路徑、檔名及格式。我把它放在與程式碼同級目錄的model檔案下,取名為mnist.pb

第二個引數 mode用於指定檔案操作的模式,’wb’中w代表寫檔案,b代表將資料以二進位制方式寫入檔案。

如果不指明‘b’,則預設會以文字txt方式寫入檔案。現在TF還不支援對文字格式.pb檔案的解析,在呼叫時會出現報錯。

注:
1)、不能使用 tf.train.write_graph()儲存模型,因為它只是儲存了模型的結構,並不儲存訓練完畢的引數值
2)、不能使用 tf.train.saver()儲存模型,因為它只是儲存了網路中的引數值,並不儲存模型的結構。
很顯然,我們需要的是既儲存模型的結構,又儲存模型中每個引數的值。以上兩者皆不符合。

三、生成在Android平臺上呼叫tensorflow 模型需要的jar包和so檔案
1) 從github下載TensorFlow的專案原始碼

3) 參考如下圖的官方教程,生成Android上呼叫TF模型需要的so檔案和jar包
這裡寫圖片描述

四、安裝Android Studio,建立Android 專案

Android Studio安裝完畢後,還需要搭建環境。搭建過程可參考我的另一篇文章:

五、新增資源到專案

1) 將(二)步生成的.pb檔案放入專案中
開啟 Project view ,app/src/main/assets。
若不存在assets目錄,右鍵main->new->folder->Assets Folder

2) 新增(三)步生成的jar包
開啟Project view,將jar包拷貝到app->libs下
選中jar檔案,右鍵 add as library

3) 新增(三)生成的so檔案
開啟 Project view,將.so檔案拷貝到 app/src/main/jniLibs下(jniLibs資料夾若沒有則新建)

如果我講的不太明白的話,可自行谷歌搜尋“如何在 Android studio中新增引用 jar檔案和so檔案”

六、建立介面,實現呼叫

1) 匯入jar包和so檔案
在需要呼叫模型的.java檔案中,匯入jar包:

import org.tensorflow.contrib.android.TensorFlowInferenceInterface

在該java類定義的首行,匯入so檔案:

{
    System.loadLibrary(“tensorflow_inference”)
}

2)定義變數及物件

private static final String MODEL_FILE = “file:///android_asset/mnist.pb”   //模型存放路徑
private static final String INPUT_NODE = “input”;       //模型中輸入變數的名稱
private static final String INPUT_NODE = “output”;  //模型中輸出變數的名稱
private static final int NUM_CLASSES = 10;  //樣本集的類別數量,mnist資料集對應10

private static final int HEIGHT = 24;       //輸入圖片的畫素高
private static final int WIDTH = 24;        //輸入圖片的畫素寬
private static final int CHANNEL = 3;    //輸入圖片的通道數:RGB

private floats inputs = new float[HEIGHT*WIDTH*CHANNEL];    //用於儲存的模型輸入資料
private floats outputs = new float[NUM_CLASSES];    //用於儲存模型的輸出資料

2)Tensorflow 介面初始化

private TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface();   //介面定義
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);  //介面初始化

在完成上述兩步之後,就可以反覆呼叫模型。
在每次呼叫前,先將待輸入的資料按順序存放進 inputs 變數中,然後執行下述三個語句。

3)TF模型的呼叫

inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs);  //送入輸入資料
inferenceInterface.runInference(new String[]{OUTPUT_NODE});     //進行模型的推理
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs); //獲取輸出資料

然後接下來的主要工作就是安卓專案的編譯以及將編譯完的apk檔案安裝到手機,這部分內容與一般的安卓專案並無區別。這些內容在我的另一篇文章中也有所提及:

為了便於大家理解,我寫的程式碼比較面向過程。當然放在java環境下,還是要多多從面向物件的角度出發,合理的封裝,提高程式碼的複用性。

我在CSDN釋出的文章,也會同步釋出到公眾號 CV_life。

歡迎有興趣的同學訂閱

這裡寫圖片描述