
Migrating a Keras or TensorFlow Model to Android (Android Studio)

After stumbling into plenty of pitfalls, I finally managed to migrate a Keras model to Android.

Tutorials online are scarce, and I took many wrong turns. A lot of them use the Bazel tool to compile TensorFlow into a .so library and a JAR, then configure the Android project to port the model. But... maybe I'm just too slow, because I never got that approach to work, and I gave up on it.

I was thrilled when the migration finally worked. The original article I followed has a few small traps of its own, though, so I've polished the tutorial on top of it.

My environment is as follows:

  • Windows 10
  • Python 3.6
  • TensorFlow 1.6.0 (March 2018)
  • Android Studio 3.2 (October 2018)
  • Android SDK: Android 9.0, API 28

In short, the overall steps for deploying a model to an Android device are:

  1. Convert the trained model into TensorFlow format;

  2. Add the TensorFlow Mobile dependency to the Android app;

  3. Write the Java code that runs inference with the TensorFlow model in your app.

Part 1: Convert the Trained Model into TensorFlow Format

If your model was trained with TensorFlow in the first place, you can skip this step and go straight to the Android Studio configuration.

Assuming you have an h5 model file saved by Keras, the following conversion is required:

Create a new Python script. (This code converts a Keras h5 model into TensorFlow's pb format; the conversion itself happens in the keras_to_tensorflow function.)

from keras.models import Model
from keras.layers import *
from keras import backend as K
import os
import tensorflow as tf
 
 
def keras_to_tensorflow(keras_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
 
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
 
    out_nodes = []
 
    for i in range(len(keras_model.outputs)):
        out_nodes.append(out_prefix + str(i + 1))
        tf.identity(keras_model.outputs[i], out_prefix + str(i + 1))
 
    sess = K.get_session()
 
    from tensorflow.python.framework import graph_util, graph_io
 
    init_graph = sess.graph.as_graph_def()
 
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
 
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
 
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
 
        import_pb_to_tensorboard.import_to_tensorboard(
            os.path.join(output_dir, model_name),
            output_dir)
 
 
"""
We explicitly redefine the Squeezent architecture since Keras has no predefined Squeezenet
"""
 
def squeezenet_fire_module(input, input_channel_small=16, input_channel_large=64):
 
    channel_axis = 3
 
    input = Conv2D(input_channel_small, (1,1), padding="valid" )(input)
    input = Activation("relu")(input)
 
    input_branch_1 = Conv2D(input_channel_large, (1,1), padding="valid" )(input)
    input_branch_1 = Activation("relu")(input_branch_1)
 
    input_branch_2 = Conv2D(input_channel_large, (3, 3), padding="same")(input)
    input_branch_2 = Activation("relu")(input_branch_2)
 
    input = concatenate([input_branch_1, input_branch_2], axis=channel_axis)
 
    return input
 
 
def SqueezeNet(input_shape=(224,224,3)):

    image_input = Input(shape=input_shape)

    network = Conv2D(64, (3,3), strides=(2,2), padding="valid")(image_input)
    network = Activation("relu")(network)
    network = MaxPool2D(pool_size=(3,3), strides=(2,2))(network)
 
    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = MaxPool2D(pool_size=(3,3), strides=(2,2))(network)
 
    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(network)
 
    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)
 
    #Remove layers like Dropout and BatchNormalization, they are only needed in training
    #network = Dropout(0.5)(network)
 
    network = Conv2D(1000, kernel_size=(1,1), padding="valid", name="last_conv")(network)
    network = Activation("relu")(network)
 
    network = GlobalAvgPool2D()(network)
    network = Activation("softmax",name="output")(network)
 
 
    model = Model(inputs=image_input, outputs=network)
 
    return model
 
 
keras_model = SqueezeNet()
 
keras_model.load_weights("squeezenet.h5")
 
 
output_dir = os.path.join(os.getcwd(),"checkpoint")
 
keras_to_tensorflow(keras_model,output_dir=output_dir,model_name="squeezenet.pb")
 
print("MODEL SAVED")

Part 2: Configure the Android Studio Dependency

Create a new project in Android Studio, then add the TensorFlow Mobile dependency to your app's build.gradle file:

 implementation 'org.tensorflow:tensorflow-android:+'

I hit a pitfall here: Gradle sync failed because the tensorflow-android-1.11.0-rc1.aar file could not be downloaded, and then, for no reason I could find, it started working a few days later. (I considered downloading tensorflow-android-1.11.0-rc1.aar directly from the web, placing it in the libs folder, and adding it as a local dependency, but since AS suddenly stopped complaining I never actually tried it. If you run into the same problem, that workaround is worth a try; a sketch follows.)
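The usual Gradle pattern for that workaround is a flatDir repository plus a local aar dependency. A minimal sketch of app/build.gradle, assuming the downloaded tensorflow-android-1.11.0-rc1.aar was placed in app/libs (untested, as noted above):

repositories {
    flatDir {
        dirs 'libs'   // look for local artifacts in app/libs
    }
}

dependencies {
    // reference the local .aar by file name (without the extension)
    implementation(name: 'tensorflow-android-1.11.0-rc1', ext: 'aar')
}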

Part 3: Write the Android Code

If you've made it this far, your environment is completely set up; all that's left is the Java code. I borrowed the code from the blog mentioned above, but its use of widgets like Snackbar kept breaking my build, and as a newcomer to Android I couldn't quite follow it. So I trimmed the code down to use only the simplest Button, TextView, and ImageView widgets.

1. Before writing the actual inference code, you need to add the converted model (squeezenet.pb) to the application's assets folder. In Android Studio, right-click your project, go to "Add Folder", and choose "Assets Folder". This creates an assets folder in your application directory. Next, copy the model into the assets folder, as shown below:
 

Here squeezenet.pb is the TensorFlow model file, testimage.jpg is the image to run the prediction on, and labels.json maps each numeric output of the model to the concrete meaning of its label.
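For reference, the getLabel helper in ImageUtils below simply looks up the stringified class index in a flat JSON object, so labels.json is expected to look like the following sketch (the label names here are made-up examples; supply one entry per class index of your model):

{
    "0": "tench",
    "1": "goldfish",
    "2": "great white shark"
}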

2. Add a new Java class to your project's main package, name it ImageUtils, and copy the code below into it.

package com.example.qing.byme_tensroflow;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.os.Environment;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import org.json.*;

/**
 * Utility class for manipulating images.
 **/
public class ImageUtils {
    /**
     * Returns a transformation matrix from one reference frame into another.
     * Handles cropping (if maintaining aspect ratio is desired) and rotation.
     *
     * @param srcWidth Width of source frame.
     * @param srcHeight Height of source frame.
     * @param dstWidth Width of destination frame.
     * @param dstHeight Height of destination frame.
     * @param applyRotation Amount of rotation to apply from one frame to another.
     *  Must be a multiple of 90.
     * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant,
     * cropping the image if necessary.
     * @return The transformation fulfilling the desired requirements.
     */
    public static Matrix getTransformationMatrix(
            final int srcWidth,
            final int srcHeight,
            final int dstWidth,
            final int dstHeight,
            final int applyRotation,
            final boolean maintainAspectRatio) {
        final Matrix matrix = new Matrix();

        if (applyRotation != 0) {
            // Translate so center of image is at origin.
            matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);

            // Rotate around origin.
            matrix.postRotate(applyRotation);
        }

        // Account for the already applied rotation, if any, and then determine how
        // much scaling is needed for each axis.
        final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;

        final int inWidth = transpose ? srcHeight : srcWidth;
        final int inHeight = transpose ? srcWidth : srcHeight;

        // Apply scaling if necessary.
        if (inWidth != dstWidth || inHeight != dstHeight) {
            final float scaleFactorX = dstWidth / (float) inWidth;
            final float scaleFactorY = dstHeight / (float) inHeight;

            if (maintainAspectRatio) {
                // Scale by minimum factor so that dst is filled completely while
                // maintaining the aspect ratio. Some image may fall off the edge.
                final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
                matrix.postScale(scaleFactor, scaleFactor);
            } else {
                // Scale exactly to fill dst from src.
                matrix.postScale(scaleFactorX, scaleFactorY);
            }
        }

        if (applyRotation != 0) {
            // Translate back from origin centered reference to destination frame.
            matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
        }

        return matrix;
    }


    public static Bitmap processBitmap(Bitmap source,int size){

        int image_height = source.getHeight();
        int image_width = source.getWidth();

        Bitmap croppedBitmap = Bitmap.createBitmap(size, size, Bitmap.Config.ARGB_8888);

        Matrix frameToCropTransformations = getTransformationMatrix(image_width,image_height,size,size,0,false);
        Matrix cropToFrameTransformations = new Matrix();
        frameToCropTransformations.invert(cropToFrameTransformations);

        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(source, frameToCropTransformations, null);

        return croppedBitmap;


    }

    public static float[] normalizeBitmap(Bitmap source,int size,float mean,float std){

        float[] output = new float[size * size * 3];

        int[] intValues = new int[source.getHeight() * source.getWidth()];

        source.getPixels(intValues, 0, source.getWidth(), 0, 0, source.getWidth(), source.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            output[i * 3] = (((val >> 16) & 0xFF) - mean)/std;
            output[i * 3 + 1] = (((val >> 8) & 0xFF) - mean)/std;
            output[i * 3 + 2] = ((val & 0xFF) - mean)/std;
        }

        return output;

    }

    public static Object[] argmax(float[] array){

        int best = -1;
        float best_confidence = 0.0f;

        for(int i = 0;i < array.length;i++){

            float value = array[i];

            if (value > best_confidence){

                best_confidence = value;
                best = i;
            }
        }

        return new Object[]{best,best_confidence};
    }


    public static String getLabel( InputStream jsonStream,int index){
        String label = "";
        try {
            byte[] jsonData = new byte[jsonStream.available()];
            jsonStream.read(jsonData);
            jsonStream.close();

            String jsonString = new String(jsonData,"utf-8");
            JSONObject object = new JSONObject(jsonString);
            label = object.getString(String.valueOf(index));
        }
        catch (Exception e){
            //ignore read/parse errors and fall through to the empty label
        }
        return label;
    }
}

If you're just using this for development, there's no need to understand the implementation of ImageUtils; knowing how to call it is enough.

3. Add the following code to your main activity. It displays the image and the prediction result.

package com.example.qing.byme_tensroflow;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.AsyncTask;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.InputStream;

public class MainActivity extends AppCompatActivity {
    /*
     * Wherever you need to call TensorFlow, load the native library with
     * System.loadLibrary("tensorflow_inference") and import
     * org.tensorflow.contrib.android.TensorFlowInferenceInterface; then you're ready to go.
     */
    //Load the tensorflow inference library
    //A static{} block runs exactly once, when the class is loaded; it is typically
    //used to initialize static fields and call static methods.
    static {
        System.loadLibrary("tensorflow_inference");
    }

    //PATH TO OUR MODEL FILE AND NAMES OF THE INPUT AND OUTPUT NODES
    private String MODEL_PATH = "file:///android_asset/squeezenet.pb";
    private String INPUT_NAME = "input_1";
    private String OUTPUT_NAME = "output_1";
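    //Note: these names must match nodes in the frozen graph. "input_1" is the
    //default name Keras assigns to the first Input layer, and "output_1" is the
    //identity node created by keras_to_tensorflow above (out_prefix + 1).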
    private TensorFlowInferenceInterface tf;

    //ARRAY TO HOLD THE PREDICTIONS AND FLOAT VALUES TO HOLD THE IMAGE DATA
    //(prediction buffer, image buffer, and input dimensions)
    float[] PREDICTIONS = new float[1000];
    private float[] floatValues;
    private int[] INPUT_SIZE = {224,224,3};

    ImageView imageView;
    TextView resultView;
    Button buttonSub;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH);

        imageView=(ImageView)findViewById(R.id.imageView1);
        resultView=(TextView)findViewById(R.id.text_show);
        buttonSub=(Button)findViewById(R.id.button1);

        buttonSub.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                try{
                    InputStream imageStream = getAssets().open("testimage.jpg");
                    Bitmap bitmap = BitmapFactory.decodeStream(imageStream);
                    imageView.setImageBitmap(bitmap);

                    predict(bitmap);

                }catch(Exception e){

                }


            }
        });

    }

    //FUNCTION TO COMPUTE THE MAXIMUM PREDICTION AND ITS CONFIDENCE
    public Object[] argmax(float[] array){

        int best = -1;
        float best_confidence = 0.0f;
        for(int i = 0;i < array.length;i++){
            float value = array[i];
            if (value > best_confidence){
                best_confidence = value;
                best = i;
            }
        }
        return new Object[]{best,best_confidence};
    }



    public void predict(final Bitmap bitmap){

        //Runs inference in background thread
        new AsyncTask<Integer,Integer,Integer>(){

            @Override
            protected Integer doInBackground(Integer ...params){
                //Resize the image into 224 x 224
                Bitmap resized_image = ImageUtils.processBitmap(bitmap,224);

                //Normalize the pixels
                floatValues = ImageUtils.normalizeBitmap(resized_image,224,127.5f,1.0f);

                //Pass input into the tensorflow
                tf.feed(INPUT_NAME,floatValues,1,224,224,3);

                //compute predictions
                tf.run(new String[]{OUTPUT_NAME});

                //copy the output into the PREDICTIONS array
                tf.fetch(OUTPUT_NAME,PREDICTIONS);

                //Obtained highest prediction
                Object[] results = argmax(PREDICTIONS);

                int class_index = (Integer) results[0];
                float confidence = (Float) results[1];

                try{
                    final String conf = String.format("%.2f", confidence * 100); //avoids substring(0,5) overflowing on short number strings
                    //Convert predicted class index into actual label name
                    final String label = ImageUtils.getLabel(getAssets().open("labels.json"),class_index);
                    //Display result on UI
                    runOnUiThread(new Runnable() {
                        @Override
                        public void run() {
                            resultView.setText(label + " : " + conf + "%");
                        }
                    });
                } catch (Exception e){
                }

                return 0;
            }

        }.execute(0);

    }
}

The model's inference logic lives in the predict function, and because it is a time-consuming operation it runs on a background thread (the AsyncTask).
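The activity expects three views with the IDs imageView1, text_show, and button1. The layout file isn't shown here, so the following minimal activity_main.xml is my own sketch:

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical">

    <!-- shows testimage.jpg when the button is clicked -->
    <ImageView
        android:id="@+id/imageView1"
        android:layout_width="match_parent"
        android:layout_height="224dp" />

    <!-- shows "label : confidence%" after inference -->
    <TextView
        android:id="@+id/text_show"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content" />

    <Button
        android:id="@+id/button1"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Predict" />
</LinearLayout>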

4. If everything above went smoothly, click Run and you're basically done!

For further reference, below is a complete MNIST demo built on the same TensorFlowInferenceInterface API: a PredictionTF wrapper class, followed by its MainActivity.

package com.example.jinquan.pan.mnist_ensorflow_androiddemo;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.Matrix;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;


public class PredictionTF {
    private static final String TAG = "PredictionTF";
    //dimensions of the model's input/output nodes
    private static final int IN_COL = 1;
    private static final int IN_ROW = 28*28;
    private static final int OUT_COL = 1;
    private static final int OUT_ROW = 1;
    //name of the input variable in the model
    private static final String inputName = "input/x_input";
    //name of the output variable in the model
    private static final String outputName = "output";

    TensorFlowInferenceInterface inferenceInterface;
    static {
        //load the libtensorflow_inference.so library
        System.loadLibrary("tensorflow_inference");
        Log.e(TAG,"libtensorflow_inference.so loaded successfully");
    }

    PredictionTF(AssetManager assetManager, String modePath) {
        //initialize the TensorFlowInferenceInterface object
        inferenceInterface = new TensorFlowInferenceInterface(assetManager, modePath);
        Log.e(TAG,"TensorFlow model file loaded successfully");
    }

    /**
     * Run inference with the trained TensorFlow model.
     * @param bitmap the input bitmap to classify
     * @return the prediction result as an int array
     */
    public int[] getPredict(Bitmap bitmap) {
        float[] inputdata = bitmapToFloatArray(bitmap, 28, 28); //the image must be scaled to 28*28
        //feed the data into tensorflow's input node
        inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
        //run tensorflow
        String[] outputNames = new String[] {outputName};
        inferenceInterface.run(outputNames);
        //fetch the result from the output node
        int[] outputs = new int[OUT_COL*OUT_ROW]; //holds the model's output
        inferenceInterface.fetch(outputName, outputs);
        return outputs;
    }

    /**
     * Convert a bitmap into a row-major float array, with every pixel
     * normalized to the 0~1 range.
     * @param bitmap the input bitmap
     * @param rx target width after scaling (columns) -> 28
     * @param ry target height after scaling (rows) -> 28
     * @return the normalized one-dimensional float array -> 28*28
     */
    public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){
        int height = bitmap.getHeight();
        int width = bitmap.getWidth();
        // compute the scale factors
        float scaleWidth = ((float) rx) / width;
        float scaleHeight = ((float) ry) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
        Log.i(TAG,"bitmap width:"+bitmap.getWidth()+",height:"+bitmap.getHeight());
        Log.i(TAG,"bitmap.getConfig():"+bitmap.getConfig());
        height = bitmap.getHeight();
        width = bitmap.getWidth();
        float[] result = new float[height*width];
        int k = 0;
        //row-major order
        for(int j = 0;j < height;j++){
            for (int i = 0;i < width;i++){
                int argb = bitmap.getPixel(i,j);
                int r = Color.red(argb);
                int g = Color.green(argb);
                int b = Color.blue(argb);
                int a = Color.alpha(argb);
                //the image is grayscale, so the r, g, b components are equal
                assert(r==g && g==b);
//                Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);
                result[k++] = r / 255.0f;
            }
        }
        return result;
    }
}
And the corresponding MainActivity of the MNIST demo:

package com.example.jinquan.pan.mnist_ensorflow_androiddemo;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;

public class MainActivity extends AppCompatActivity {

    // Used to load the 'native-lib' library on application startup.
    static {
        System.loadLibrary("native-lib");
    }

    private static final String TAG = "MainActivity";
    private static final String MODEL_FILE = "file:///android_asset/mnist.pb"; //path to the model in assets
    TextView txt;
    TextView tv;
    ImageView imageView;
    Bitmap bitmap;
    PredictionTF preTF;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // Example of a call to a native method
        tv = (TextView) findViewById(R.id.sample_text);
        txt=(TextView)findViewById(R.id.txt_id);
        imageView =(ImageView)findViewById(R.id.imageView1);
        bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test_image);
        imageView.setImageBitmap(bitmap);
        preTF = new PredictionTF(getAssets(), MODEL_FILE); //pass the model path and load the TensorFlow model
    }

    public void click01(View v){
        String res="預測結果為:";
        int[] result= preTF.getPredict(bitmap);
        for (int i=0;i<result.length;i++){
            Log.i(TAG, res+result[i] );
            res=res+String.valueOf(result[i])+" ";
        }
        txt.setText(res);
        tv.setText(stringFromJNI());
    }
    /**
     * A native method that is implemented by the 'native-lib' native library,
     * which is packaged with this application.
     */
    public native String stringFromJNI();
}