1. 程式人生 > >使用TensorFlow Lite在Android手機上實現影象分類

使用TensorFlow Lite在Android手機上實現影象分類

*本篇文章已授權微信公眾號 guolin_blog (郭霖)獨家釋出

前言

TensorFlow Lite是一款專門針對移動裝置的深度學習框架,移動裝置深度學習框架是部署在手機或者樹莓派等小型移動裝置上的深度學習框架,可以使用訓練好的模型在手機等裝置上完成推理任務。這一類框架的出現,可以使得一些推理的任務可以在本地執行,不需要再呼叫伺服器的網路介面,大大減少了預測時間。在前幾篇文章中已經介紹了百度的paddle-mobile,小米的mace,還有騰訊的ncnn。這在本章中我們將介紹谷歌的TensorFlow Lite。

轉換模型

手機上執行預測,首先需要一個訓練好的模型,這個模型不能是TensorFlow原來格式的模型,TensorFlow Lite使用的模型格式是另一種格式的模型。下面就介紹如何使用這個格式的模型。

獲取模型主要有兩種方法,第一種是在訓練的時候就儲存tflite模型,另外一種就是使用其他格式的TensorFlow模型轉換成tflite模型。
1、最方便的就是在訓練的時候儲存tflite格式的模型,主要是使用到tf.contrib.lite.toco_convert()介面,下面就是一個簡單的例子:

import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant(
[1., 4., 4.]) out = tf.identity(val, name="out") with tf.Session() as sess: tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) open("converteds_model.tflite", "wb").write(tflite_model)

最後獲得的converteds_model.tflite檔案就可以直接在TensorFlow Lite上使用。

2、第二種就是把tensorflow儲存的其他模型轉換成tflite

,我們可以在以下的連結下載模型:

上面提供的模型同時也包括了tflite模型,我們可以直接拿來使用,但是我們也可以使用其他格式的模型來轉換。比如我們下載一個mobilenet_v1_1.0_224.tgz,解壓之後獲得以下檔案:

mobilenet_v1_1.0_224.ckpt.data-00000-of-00001  mobilenet_v1_1.0_224_eval.pbtxt  mobilenet_v1_1.0_224.tflite
mobilenet_v1_1.0_224.ckpt.index                mobilenet_v1_1.0_224_frozen.pb
mobilenet_v1_1.0_224.ckpt.meta                 mobilenet_v1_1.0_224_info.txt

然後克隆TensorFlow的原始碼:

git clone https://github.com/tensorflow/tensorflow.git

接著編譯轉換工具,這個編譯時間可能比較長:

cd tensorflow/
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toco:toco

獲得到轉換工具之後,我們就可以開始轉換模型了,以下操作是凍結圖。

  • input_graph對應的是.pb檔案;
  • input_checkpoint對應的是mobilenet_v1_1.0_224.ckpt.data-00000-of-00001,但是在使用的使用是去掉字尾名的。
  • output_node_names這個可以在mobilenet_v1_1.0_224_info.txt中獲取。

不過要注意的是我們下載的模型已經是凍結過來,所以不用再執行這個操作。但如果是其他的模型,要先凍結圖,然後再執行之後的操作。

./freeze_graph --input_graph=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \
  --input_checkpoint=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
  --input_binary=true \
  --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
  --output_node_names=MobilenetV1/Predictions/Reshape_1

以下操作就是把已經凍結的圖轉換成.tflite

  • input_file是已經凍結的圖;
  • output_file是轉換後輸出的路徑;
  • output_arrays這個可以在mobilenet_v1_1.0_224_info.txt中獲取;
  • input_shapes這個是預測資料的shape
./toco --input_file=/tmp/mobilenet_v1_1.0_224_frozen.pb \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
  --inference_type=FLOAT \
  --input_type=FLOAT \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1 \
  --input_shapes=1,224,224,3

經過上面的步驟就可以獲取到mobilenet_v1_1.0_224.tflite模型了,之後我們會在Android專案中使用它。

開發Android專案

有了上面的模型之後,我們就使用Android Studio建立一個Android專案,一路預設就可以了,並不需要C++的支援,因為我們使用到的TensorFlow Lite是Java程式碼的,開發起來非常方便。

1、建立完成之後,在app目錄下的build.gradle配置檔案加上以下配置資訊:
dependencies下加上包的引用,第一個是圖片載入框架Glide,第二個就是我們這個專案的核心TensorFlow Lite:

    implementation 'com.github.bumptech.glide:glide:4.3.1'
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'

然後在android下加上以下程式碼,這個主要是限制不要對tensorflow lite的模型進行壓縮,壓縮之後就無法載入模型了:

    //set no compress models
    aaptOptions {
        noCompress "tflite"
    }

2、在main目錄下建立assets資料夾,這個資料夾主要是存放tflite模型和label名稱檔案。

3、以下是主介面的程式碼MainActivity.java,這個程式碼比較長,我們來分析這段程式碼,重要的方法介紹如下:

  • loadModelFile()方法是把模型檔案讀取成MappedByteBuffer,之後給Interpreter類初始化模型,這個模型存放在mainassets目錄下。
  • load_model()方法是載入模型,並得到一個物件tflite,之後就是使用這個物件來預測影象,同時可以使用這個物件設定一些引數,比如設定使用的執行緒數量tflite.setNumThreads(4);
  • showDialog()方法是顯示彈窗,通過這個彈窗的選擇不同的模型。
  • readCacheLabelFromLocalFile()方法是讀取檔案種分類標籤對應的名稱,這個檔案比較長,可以參考這篇文章獲取標籤名稱,也可以下載筆者的專案,裡面有對用的檔案。這個檔案cacheLabel.txt跟模型一樣存放在assets目錄下。
  • predict_image()方法是預測圖片並顯示結果的,預測的流程是:獲取圖片的路徑,然後使用對圖片進行壓縮,之後把圖片轉換成ByteBuffer格式的資料,最後呼叫tflite.run()方法進行預測。
  • get_max_result()方法是獲取最大概率的標籤。
package com.yeyupiaoling.testtflite;

import android.Manifest;
import android.app.Activity;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.text.method.ScrollingMovementMethod;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import com.bumptech.glide.Glide;
import com.bumptech.glide.load.engine.DiskCacheStrategy;
import com.bumptech.glide.request.RequestOptions;
import org.tensorflow.lite.Interpreter;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;

public class MainActivity extends AppCompatActivity {
    private static final String TAG = MainActivity.class.getName();
    private static final int USE_PHOTO = 1001;
    private static final int START_CAMERA = 1002;
    private String camera_image_path;
    private ImageView show_image;
    private TextView result_text;
    private String assets_path = "lite_images";
    private boolean load_result = false;
    private int[] ddims = {1, 3, 224, 224};
    private int model_index = 0;
    private List<String> resultLabel = new ArrayList<>();
    private Interpreter tflite = null;

    private static final String[] PADDLE_MODEL = {
            "mobilenet_v1",
            "mobilenet_v2"
    };


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        init_view();
        readCacheLabelFromLocalFile();
    }

    // initialize view
    private void init_view() {
        request_permissions();
        show_image = (ImageView) findViewById(R.id.show_image);
        result_text = (TextView) findViewById(R.id.result_text);
        result_text.setMovementMethod(ScrollingMovementMethod.getInstance());
        Button load_model = (Button) findViewById(R.id.load_model);
        Button use_photo = (Button) findViewById(R.id.use_photo);
        Button start_photo = (Button) findViewById(R.id.start_camera);

        load_model.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                showDialog();
            }
        });

        // use photo click
        use_photo.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                if (!load_result) {
                    Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
                    return;
                }
                PhotoUtil.use_photo(MainActivity.this, USE_PHOTO);
            }
        });

        // start camera click
        start_photo.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                if (!load_result) {
                    Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
                    return;
                }
                camera_image_path = PhotoUtil.start_camera(MainActivity.this, START_CAMERA);
            }
        });
    }

    /**
     * Memory-map the model file in Assets.
     */
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }


    // load infer model
    private void load_model(String model) {
        try {
            tflite = new Interpreter(loadModelFile(model));
            Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
            Log.d(TAG, model + " model load success");
            tflite.setNumThreads(4);
            load_result = true;
        } catch (IOException e) {
            Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
            Log.d(TAG, model + " model load fail");
            load_result = false;
            e.printStackTrace();
        }
    }

    public void showDialog() {
        AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);

        // set dialog title
        builder.setTitle("Please select model");

        // set dialog icon
        builder.setIcon(android.R.drawable.ic_dialog_alert);

        // able click other will cancel
        builder.setCancelable(true);

        // cancel button
        builder.setNegativeButton("cancel", null);

        // set list
        builder.setSingleChoiceItems(PADDLE_MODEL, model_index, new DialogInterface.OnClickListener() {
            @Override
            public void onClick(DialogInterface dialog, int which) {
                model_index = which;
                load_model(PADDLE_MODEL[model_index]);
                dialog.dismiss();
            }
        });

        // show dialog
        builder.show();
    }


    private void readCacheLabelFromLocalFile() {
        try {
            AssetManager assetManager = getApplicationContext().getAssets();
            BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel.txt")));
            String readLine = null;
            while ((readLine = reader.readLine()) != null) {
                resultLabel.add(readLine);
            }
            reader.close();
        } catch (Exception e) {
            Log.e("labelCache", "error " + e);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
        String image_path;
        RequestOptions options = new RequestOptions().skipMemoryCache(true).diskCacheStrategy(DiskCacheStrategy.NONE);
        if (resultCode == Activity.RESULT_OK) {
            switch (requestCode) {
                case USE_PHOTO:
                    if (data == null) {
                        Log.w(TAG, "user photo data is null");
                        return;
                    }
                    Uri image_uri = data.getData();
                    Glide.with(MainActivity.this).load(image_uri).apply(options).into(show_image);
                    // get image path from uri
                    image_path = PhotoUtil.get_path_from_URI(MainActivity.this, image_uri);
                    // predict image
                    predict_image(image_path);
                    break;
                case START_CAMERA:
                    // show photo
                    Glide.with(MainActivity.this).load(camera_image_path).apply(options).into(show_image);
                    // predict image
                    predict_image(camera_image_path);
                    break;
            }
        }
    }

    //  predict image
    private void predict_image(String image_path) {
        // picture to float array
        Bitmap bmp = PhotoUtil.getScaleBitmap(image_path);
        ByteBuffer inputData = PhotoUtil.getScaledMatrix(bmp, ddims);
        try {
            // Data format conversion takes too long
            // Log.d("inputData", Arrays.toString(inputData));
            float[][] labelProbArray = new float[1][1001];
            long start = System.currentTimeMillis();
            // get predict result
            tflite.run(inputData, labelProbArray);
            long end = System.currentTimeMillis();
            long time = end - start;
            float[] results = new float[labelProbArray[0].length];
            System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
            // show predict result and time
            int r = get_max_result(results);
            String show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
            result_text.setText(show_text);
        } catch (Exception e) {
            e.printStackTrace();
        }

    // get max probability label
    private int get_max_result(float[] result) {
        float probability = result[0];
        int r = 0;
        for (int i = 0; i < result.length; i++) {
            if (probability < result[i]) {
                probability = result[i];
                r = i;
            }
        }
        return r;
    }

    // request permissions
    private void request_permissions() {

        List<String> permissionList = new ArrayList<>();
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            permissionList.add<