
Porting TensorFlow to Android: Object Detection with Automatic Photo Capture

I. Overview

1. What is TensorFlow?

    TensorFlow is Google's second-generation machine learning system, developed as the successor to DistBelief.

2. Why use TensorFlow for on-device object detection?

    As described in the previous post, I ported OpenCV to Android to take a photo automatically whenever a specified object was detected. The feature did work, but OpenCV's approach is fairly dated and its error rate was high, so I switched to TensorFlow for object classification, which improved both the detection rate and the accuracy considerably.

This post explains how to port TensorFlow to Android and how to train a custom model.

II. Prerequisites

1. Download the TensorFlow project (https://github.com/tensorflow/tensorflow)

2. A Python environment (I used Python 3.5)

3. Install TensorFlow. There are two builds, CPU and GPU; the GPU build trains models much faster, so choose according to your hardware. If you have Anaconda installed, simply run: conda install tensorflow. I used the GPU build, which requires an NVIDIA graphics card: conda install tensorflow-gpu

4. Android Studio (I used version 2.3.3)

5. Download the libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar files (both can be built from the TensorFlow source). Link: https://pan.baidu.com/s/1tN_nNqfy6JC272J17VaWTg  Password: boat

III. Training a custom TensorFlow model

1. Prepare the dataset

When training a classification model with TensorFlow, you need at least two classes (there is no implicit "background" class). For example:

To recognize the two classes "people" and "dogs", prepare about 100 images containing only dogs and put them in a dogs folder, and about 100 images containing only people and put them in a peoples folder. The more images you have, the more accurate the trained model; around 100 per class is just barely enough.

Place the two class folders under:

tensorflow_master/tensorflow/examples/image_retraining/data/train 中

If the data/train folder does not exist, create it yourself; the resulting layout is sketched below.
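
For reference, the layout assumed by the example above would look like this (the subfolder names dogs and peoples come from the example; each subfolder name becomes a class label in output_labels.txt):

tensorflow_master/tensorflow/examples/image_retraining/data/train/
    dogs/       (about 100 images containing only dogs)
    peoples/    (about 100 images containing only people)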

2. Prepare the pre-trained model

Retraining requires the ImageNet pre-trained weights, four files in total (classify_image_graph_def.pb, imagenet_2012_challenge_label_map_proto.pbtxt, imagenet_synset_to_human_label_map.txt, inception-2015-12-05.tgz). Download link: https://pan.baidu.com/s/1JlDbYy4NHD7qy3Or5lDtSg  Password: i3jo

Download them ahead of time and copy them into the model folder (create it if it does not exist); otherwise the script will try to download them itself, which is very slow.

3. Start training

cd into the tensorflow_master\tensorflow\examples\image_retraining folder:

Run the command:

python retrain.py --bottleneck_dir bottleneck --how_many_training_steps 4000 --model_dir model/ --output_graph output_graph.pb --output_labels output_labels.txt --image_dir data/train/

When it finishes, two files are generated in the tensorflow_master\tensorflow\examples\image_retraining folder:

output_graph.pb and output_labels.txt
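
Optionally, you can sanity-check the retrained graph on a test image before converting it, using the label_image.py script that ships with TensorFlow under tensorflow\examples\label_image. A sketch (the flag names below match the TF 1.x version of that script; test.jpg is a placeholder for one of your own images):

python label_image.py --graph=output_graph.pb --labels=output_labels.txt --input_layer=Mul --output_layer=final_result --input_height=299 --input_width=299 --input_mean=128 --input_std=128 --image=test.jpg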

4. The model generated in the previous step cannot be dropped into Android directly; it needs one more conversion. The official note:

To use v3 Inception model, strip the DecodeJpeg Op from your retrained model first:

cd into the tensorflow_master\tensorflow\python\tools folder, copy the output_graph.pb file generated in the previous step into that directory, and run:

python strip_unused.py --input_graph=output_graph.pb --output_graph=output.pb --input_node_names="Mul" --output_node_names="final_result" --input_binary=true

This generates an output.pb file in that directory.

With that, model training is complete.

IV. Integrating the Android project

1. After creating a new project, create two directories, assets and jniLibs, under \app\src\main, and copy the output.pb and output_labels.txt files generated earlier into the assets folder.

2. Create an armeabi-v7a folder under jniLibs and copy libtensorflow_inference.so into the jniLibs\armeabi-v7a folder.

3. Add libandroid_tensorflow_inference_java.jar to the project; if you have not done this before, search for "Android Studio add jar". A typical build.gradle entry is sketched below.
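
A minimal sketch, assuming the jar has been copied into app/libs (compile is the dependency syntax of the Gradle plugin bundled with Android Studio 2.3.3):

dependencies {
    compile files('libs/libandroid_tensorflow_inference_java.jar')
}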

4. Create a new class (Classifier.java):

import android.graphics.Bitmap;
import android.graphics.RectF;

import java.util.List;

/**
 * Created by amitshekhar on 06/03/17.
 */

/**
 * Generic interface for interacting with different recognition engines.
 */
public interface Classifier {
    /**
     * An immutable result returned by a Classifier describing what was recognized.
     */
    public class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance of
         * the object.
         */
        private final String id;

        /**
         * Display name for the recognition.
         */
        private final String title;

        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private final Float confidence;

        /**
         * Optional location within the source image for the location of the recognized object.
         */
        private RectF location;

        public Recognition(
                final String id, final String title, final Float confidence, final RectF location) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.location = location;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        public RectF getLocation() {
            return new RectF(location);
        }

        public void setLocation(RectF location) {
            this.location = location;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }

            if (title != null) {
                resultString += title + " ";
            }

            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }

            if (location != null) {
                resultString += location + " ";
            }

            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);

    void enableStatLogging(final boolean debug);

    String getStatString();

    void close();
}

5. Create the recognition implementation class (TensorFlowImageClassifier.java):
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.support.v4.os.TraceCompat;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

import www.demo04.com.util.tensorflow.Classifier;

/**
 * Created by amitshekhar on 06/03/17.
 */

/**
 * A classifier specialized to label images using TensorFlow.
 */
public class TensorFlowImageClassifier implements Classifier {

    private static final String TAG = "ImageClassifier";

    // Only return this many results with at least this confidence.
    private static final int MAX_RESULTS = 2;
    private static final float THRESHOLD = 0.1f;

    // Config values.
    private String inputName;
    private String outputName;
    private int inputSize;
    private int imageMean;
    private float imageStd;

    // Pre-allocated buffers.
    private Vector<String> labels = new Vector<String>();
    private int[] intValues;
    private float[] floatValues;
    private float[] outputs;
    private String[] outputNames;

    private TensorFlowInferenceInterface inferenceInterface;

    private boolean runStats = false;

    private TensorFlowImageClassifier() {
    }

    /**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager  The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize     The input size. A square image of inputSize x inputSize is assumed.
     * @param imageMean     The assumed mean of the image values.
     * @param imageStd      The assumed std of the image values.
     * @param inputName     The label of the image input node.
     * @param outputName    The label of the output node.
     * @throws IOException
     */
    public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            int imageMean,
            float imageStd,
            String inputName,
            String outputName)
            throws IOException {
        TensorFlowImageClassifier c = new TensorFlowImageClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory.
        // TODO(andrewharp): make this handle non-assets.
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.i(TAG, "Reading labels from: " + actualFilename);
        BufferedReader br =
                new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
        String line;
        while ((line = br.readLine()) != null) {
            c.labels.add(line);
        }
        br.close();

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
        // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
        int numClasses =
                (int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
        Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

        // Ideally, inputSize could have been retrieved from the shape of the input operation.  Alas,
        // the placeholder node for input in the graphdef typically used does not specify a shape, so it
        // must be passed in as a parameter.
        c.inputSize = inputSize;
        c.imageMean = imageMean;
        c.imageStd = imageStd;

        // Pre-allocate buffers.
        c.outputNames = new String[]{outputName};
        c.intValues = new int[inputSize * inputSize];
        c.floatValues = new float[inputSize * inputSize * 3];
        c.outputs = new float[numClasses];


        /*if(c.inferenceInterface != null && c.inferenceInterface.graph() != null && c.inferenceInterface.graph().operations()!=null){
            Iterator<Operation> operations = c.inferenceInterface.graph().operations();
            while(operations.hasNext()){
                Log.e("operation : ",""+operations.next().name());
            }
        }*/


        return c;
    }

    @Override
    public List<Recognition> recognizeImage(final Bitmap bitmap) {
        // Log this method so that it can be analyzed with systrace.
        TraceCompat.beginSection("recognizeImage");

        TraceCompat.beginSection("preprocessBitmap");
        // Preprocess the image data from 0-255 int to normalized float based
        // on the provided parameters.
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
        }
        TraceCompat.endSection();

        // Copy the input data into TensorFlow.
        TraceCompat.beginSection("feed");
        inferenceInterface.feed(
                inputName, floatValues, new long[]{1, inputSize, inputSize, 3});
        TraceCompat.endSection();

        // Run the inference call.
        TraceCompat.beginSection("run");
        inferenceInterface.run(outputNames, runStats);
        TraceCompat.endSection();

        // Copy the output Tensor back into the output array.
        TraceCompat.beginSection("fetch");
        inferenceInterface.fetch(outputName, outputs);
        TraceCompat.endSection();

        // Find the best classifications.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(
                        3,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(
                        new Recognition(
                                "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
            }
        }
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        TraceCompat.endSection(); // "recognizeImage"
        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean debug) {
        runStats = debug;
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}

6. In the Activity that performs recognition

Define some constants:

    private static final int INPUT_SIZE = 299;
    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 128;

    private static final String INPUT_NAME = "Mul";
    private static final String OUTPUT_NAME = "final_result";
    private static final String MODEL_FILE = "file:///android_asset/output.pb";
    private static final String LABEL_FILE ="file:///android_asset/output_labels.txt";

Add a method to initialize TensorFlow:

private void initTensorFlowAndLoadModel() {
        executor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    classifier = TensorFlowImageClassifier.create(
                            getAssets(),
                            MODEL_FILE,
                            LABEL_FILE,
                            INPUT_SIZE,
                            IMAGE_MEAN,
                            IMAGE_STD,
                            INPUT_NAME,
                            OUTPUT_NAME);
                } catch (final Exception e) {
                    throw new RuntimeException("Error initializing TensorFlow!", e);
                }
            }
        });
    }
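
The executor and classifier fields referenced by this method are not shown in the snippet above; a minimal sketch of how they might be declared, plus matching cleanup in onDestroy, is below (a single-threaded executor keeps all TensorFlow calls off the UI thread; the field names are assumptions):

import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

    private Classifier classifier;
    private final Executor executor = Executors.newSingleThreadExecutor();

    @Override
    protected void onDestroy() {
        super.onDestroy();
        executor.execute(new Runnable() {
            @Override
            public void run() {
                // Release the native TensorFlow session.
                if (classifier != null) {
                    classifier.close();
                }
            }
        });
    }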

The official guidance is to use 299 x 299 input images. Several other sizes I tried all caused problems: some were too large, others raised an error complaining the buffer was not a multiple of 2048. Rather than fight with it, simply scale the bitmap first; it is a one-liner:

rightBitmap = Bitmap.createScaledBitmap(rightBitmap, 299, 299, true);

To run recognition, just call:

final List<Classifier.Recognition> results = classifier.recognizeImage(rightBitmap);

The returned results is a List holding the names of the predicted objects along with their confidence scores, which you can then act on, for example as sketched below.
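
A minimal sketch of triggering the automatic photo, assuming the "dogs" label from the training step, an arbitrary 0.7 confidence threshold, and a hypothetical takePicture() helper wired to the camera code:

for (final Classifier.Recognition result : results) {
    Log.d("Recognition", result.toString());
    // Take the photo automatically once the target class is detected
    // with high enough confidence.
    if ("dogs".equals(result.getTitle()) && result.getConfidence() > 0.7f) {
        takePicture(); // hypothetical camera helper
        break;
    }
}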

The results are noticeably more accurate than the earlier OpenCV version.

The Android app is similar to the previous Android project; only the recognition code has been replaced.

Finally, the download link for the Android source code (the TensorFlow model is not included, as it makes the project too large):

https://download.csdn.net/download/qq_27063119/10346591