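Porting TensorFlow to Android: Automatic Photo Capture with Object Detection
I. Overview
1. What TensorFlow is:
TensorFlow is Google's second-generation machine learning system, developed as the successor to DistBelief.
2. Why use TensorFlow for object detection on the phone:
The previous post covered porting OpenCV to Android so that a photo is taken automatically when a specified object is detected. That works, but the OpenCV approach is dated and its misrecognition rate is high, so I brought in TensorFlow for object classification, which considerably improved both the recognition rate and the accuracy.
This post explains how TensorFlow is ported to the phone and how a custom model is trained.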
II. Prerequisites
1. Download the TensorFlow project (https://github.com/tensorflow/tensorflow)
2. A Python environment (I used Python 3.5)
3. Install TensorFlow. There are two builds, CPU and GPU. The GPU build trains models much faster, so choose according to your hardware. With Anaconda installed, run: conda install tensorflow. I used the GPU build, which requires an NVIDIA graphics card: conda install tensorflow-gpu
4. Android Studio (I used version 2.3.3)
5. Download libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar (both can also be built from source). Link: https://pan.baidu.com/s/1tN_nNqfy6JC272J17VaWTg password: boat
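III. Training a custom TensorFlow model
1. Prepare the dataset
Training requires at least two classes (there is no implicit background class). For example:
To recognize people and dogs, prepare about 100 images that contain only dogs and put them in a dogs folder, then prepare about 100 images that contain only people and put them in a peoples folder. More images give a more accurate model; about 100 per class is only barely enough.
Place both class folders under:
tensorflow_master/tensorflow/examples/image_retraining/data/train
Create the data/train folder yourself if it does not exist.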
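Before training, it can help to confirm that the folder layout matches the description above. The following is a minimal sketch, not part of the TensorFlow project: the function names and the list of accepted extensions are assumptions made for illustration.

```python
import os

# Extensions counted as images; an assumption for this sketch.
IMAGE_EXTENSIONS = (".jpg", ".jpeg")


def count_images_per_class(train_dir):
    """Return {class_folder_name: image_count} for each subfolder of train_dir."""
    counts = {}
    for name in sorted(os.listdir(train_dir)):
        class_dir = os.path.join(train_dir, name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files next to the class folders
        counts[name] = sum(
            1 for f in os.listdir(class_dir)
            if f.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


def check_dataset(train_dir, min_classes=2):
    """Raise ValueError if train_dir holds fewer class folders than required."""
    counts = count_images_per_class(train_dir)
    if len(counts) < min_classes:
        raise ValueError(
            "need at least {} class folders, found {}".format(
                min_classes, len(counts)))
    return counts
```

Calling check_dataset("data/train") from the image_retraining folder would return something like a dogs count and a peoples count, which makes it easy to spot an empty or misnamed folder.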
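2. Prepare the pretrained model
Training uses ImageNet pretrained weights, four files in total (classify_image_graph_def.pb, imagenet_2012_challenge_label_map_proto.pbtxt, imagenet_synset_to_human_label_map.txt, inception-2015-12-05.tgz). Download link: https://pan.baidu.com/s/1JlDbYy4NHD7qy3Or5lDtSg password: i3jo
Download them in advance and copy them into the model folder (create it if it does not exist). Otherwise the script downloads them automatically, which is very slow.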
3. Start training
cd into the tensorflow_master\tensorflow\examples\image_retraining folder.
Run:
python retrain.py --bottleneck_dir bottleneck --how_many_training_steps 4000 --model_dir model/ --output_graph output_graph.pb --output_labels output_labels.txt --image_dir data/train/
When it finishes, two files are generated in the tensorflow_master\tensorflow\examples\image_retraining folder:
output_graph.pb and output_labels.txt
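4. The model generated in the previous step cannot be placed into the Android project directly; it needs one conversion step. The official explanation:
To use v3 Inception model, strip the DecodeJpeg Op from your retrained model first:
cd into the tensorflow_master\tensorflow\python\tools folder, copy the output_graph.pb generated in the previous step into that directory, and run: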
python strip_unused.py --input_graph=output_graph.pb --output_graph=output.pb --input_node_names="Mul" --output_node_names="final_result" --input_binary=true
This generates output.pb in that directory.
Model training is now complete.
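The two commands in this section can also be kept in one small script so the flags stay consistent between runs. This is only a sketch: the helper names are hypothetical, the flags are copied from the commands in this post, and the working directories still have to be supplied by the caller.

```python
import subprocess


def retrain_command(image_dir="data/train/", steps=4000):
    """Argument list for retrain.py, using the flags shown in this post."""
    return [
        "python", "retrain.py",
        "--bottleneck_dir", "bottleneck",
        "--how_many_training_steps", str(steps),
        "--model_dir", "model/",
        "--output_graph", "output_graph.pb",
        "--output_labels", "output_labels.txt",
        "--image_dir", image_dir,
    ]


def strip_command(input_graph="output_graph.pb", output_graph="output.pb"):
    """Argument list for strip_unused.py, using the flags shown in this post."""
    return [
        "python", "strip_unused.py",
        "--input_graph=" + input_graph,
        "--output_graph=" + output_graph,
        "--input_node_names=Mul",
        "--output_node_names=final_result",
        "--input_binary=true",
    ]


def run(cmd, cwd):
    """Run one command inside the given folder and fail loudly on error."""
    subprocess.check_call(cmd, cwd=cwd)
```

For example, run(retrain_command(), cwd=...) would be called with the image_retraining folder, and run(strip_command(), cwd=...) with the python\tools folder after output_graph.pb has been copied there.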
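IV. Integrating into the Android project
1. After creating a new project, create two directories, assets and jniLibs, under \app\src\main, and copy the generated output.pb and output_labels.txt files into the assets folder.
2. Create an armeabi-v7a folder under jniLibs and copy libtensorflow_inference.so into jniLibs\armeabi-v7a.
3. Add libandroid_tensorflow_inference_java.jar to the project. If you are unsure how, search for how to add a jar in Android Studio.
4. Create a new interface (Classifier.java):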
import android.graphics.Bitmap;
import android.graphics.RectF;

import java.util.List;

/**
 * Created by amitshekhar on 06/03/17.
 */

/**
 * Generic interface for interacting with different recognition engines.
 */
public interface Classifier {
    /**
     * An immutable result returned by a Classifier describing what was recognized.
     */
    public class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance of
         * the object.
         */
        private final String id;

        /**
         * Display name for the recognition.
         */
        private final String title;

        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private final Float confidence;

        /**
         * Optional location within the source image for the location of the recognized object.
         */
        private RectF location;

        public Recognition(
                final String id, final String title, final Float confidence, final RectF location) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.location = location;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        public RectF getLocation() {
            return new RectF(location);
        }

        public void setLocation(RectF location) {
            this.location = location;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }
            if (title != null) {
                resultString += title + " ";
            }
            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }
            if (location != null) {
                resultString += location + " ";
            }
            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);

    void enableStatLogging(final boolean debug);

    String getStatString();

    void close();
}
5. Create the classifier implementation (TensorFlowImageClassifier.java):
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.support.v4.os.TraceCompat;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

// Adjust this import to the package where Classifier.java lives in your project.
import www.demo04.com.util.tensorflow.Classifier;

/**
 * Created by amitshekhar on 06/03/17.
 */

/**
 * A classifier specialized to label images using TensorFlow.
 */
public class TensorFlowImageClassifier implements Classifier {
    private static final String TAG = "ImageClassifier";

    // Only return this many results with at least this confidence.
    private static final int MAX_RESULTS = 2;
    private static final float THRESHOLD = 0.1f;

    // Config values.
    private String inputName;
    private String outputName;
    private int inputSize;
    private int imageMean;
    private float imageStd;

    // Pre-allocated buffers.
    private Vector<String> labels = new Vector<String>();
    private int[] intValues;
    private float[] floatValues;
    private float[] outputs;
    private String[] outputNames;

    private TensorFlowInferenceInterface inferenceInterface;
    private boolean runStats = false;

    private TensorFlowImageClassifier() {
    }

    /**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager  The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize     The input size. A square image of inputSize x inputSize is assumed.
     * @param imageMean     The assumed mean of the image values.
     * @param imageStd      The assumed std of the image values.
     * @param inputName     The label of the image input node.
     * @param outputName    The label of the output node.
     * @throws IOException
     */
    public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            int imageMean,
            float imageStd,
            String inputName,
            String outputName)
            throws IOException {
        TensorFlowImageClassifier c = new TensorFlowImageClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory.
        // TODO(andrewharp): make this handle non-assets.
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.i(TAG, "Reading labels from: " + actualFilename);
        BufferedReader br =
                new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                c.labels.add(line);
            }
        } finally {
            // Close the reader even if reading a line fails.
            br.close();
        }

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

        // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
        int numClasses =
                (int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
        Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

        // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
        // the placeholder node for input in the graphdef typically used does not specify a shape, so it
        // must be passed in as a parameter.
        c.inputSize = inputSize;
        c.imageMean = imageMean;
        c.imageStd = imageStd;

        // Pre-allocate buffers.
        c.outputNames = new String[]{outputName};
        c.intValues = new int[inputSize * inputSize];
        c.floatValues = new float[inputSize * inputSize * 3];
        c.outputs = new float[numClasses];

        /*if(c.inferenceInterface != null && c.inferenceInterface.graph() != null && c.inferenceInterface.graph().operations()!=null){
            Iterator<Operation> operations = c.inferenceInterface.graph().operations();
            while(operations.hasNext()){
                Log.e("operation : ",""+operations.next().name());
            }
        }*/

        return c;
    }

    @Override
    public List<Recognition> recognizeImage(final Bitmap bitmap) {
        // Log this method so that it can be analyzed with systrace.
        TraceCompat.beginSection("recognizeImage");

        TraceCompat.beginSection("preprocessBitmap");
        // Preprocess the image data from 0-255 int to normalized float based
        // on the provided parameters.
        // The bitmap must be inputSize x inputSize: intValues was allocated for exactly that many pixels.
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
        }
        TraceCompat.endSection();

        // Copy the input data into TensorFlow.
        TraceCompat.beginSection("feed");
        inferenceInterface.feed(
                inputName, floatValues, new long[]{1, inputSize, inputSize, 3});
        TraceCompat.endSection();

        // Run the inference call.
        TraceCompat.beginSection("run");
        inferenceInterface.run(outputNames, runStats);
        TraceCompat.endSection();

        // Copy the output Tensor back into the output array.
        TraceCompat.beginSection("fetch");
        inferenceInterface.fetch(outputName, outputs);
        TraceCompat.endSection();

        // Find the best classifications.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(
                        3,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(
                        new Recognition(
                                "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
            }
        }
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        TraceCompat.endSection(); // "recognizeImage"
        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean debug) {
        runStats = debug;
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}
6. In the Activity that performs recognition
Define some constants:
private static final int INPUT_SIZE = 299;
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128;
private static final String INPUT_NAME = "Mul";
private static final String OUTPUT_NAME = "final_result";
private static final String MODEL_FILE = "file:///android_asset/output.pb";
private static final String LABEL_FILE = "file:///android_asset/output_labels.txt";
Add a method that initializes TensorFlow:
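// Assumes the Activity declares two fields whose declarations are not shown in this post:
// `executor` (an Executor that runs work off the main thread) and `classifier` (a Classifier).
private void initTensorFlowAndLoadModel() {
    executor.execute(new Runnable() {
        @Override
        public void run() {
            try {
                classifier = TensorFlowImageClassifier.create(
                        getAssets(),
                        MODEL_FILE,
                        LABEL_FILE,
                        INPUT_SIZE,
                        IMAGE_MEAN,
                        IMAGE_STD,
                        INPUT_NAME,
                        OUTPUT_NAME);
            } catch (final Exception e) {
                throw new RuntimeException("Error initializing TensorFlow!", e);
            }
        }
    });
}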
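The official guidance is to use 299 x 299 images. I tried several other sizes and every one failed: some were reported as too large, others produced an error about not being a multiple of 2048. This is consistent with the classifier code, which allocates its pixel buffers and the feed shape from INPUT_SIZE, so the bitmap passed to recognizeImage has to be exactly INPUT_SIZE x INPUT_SIZE. Rather than fight it, scale the image to that size (note that this scales rather than crops), which takes one line of code: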
rightBitmap = Bitmap.createScaledBitmap(rightBitmap, 299, 299, true);
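The preprocessing loop in recognizeImage turns each packed ARGB pixel into three floats using IMAGE_MEAN and IMAGE_STD. The arithmetic can be checked on the desktop with a few lines of Python. This is a sketch that mirrors the Java loop for illustration only; the function names are hypothetical, and the constants are the ones defined in the Activity.

```python
INPUT_SIZE = 299
IMAGE_MEAN = 128
IMAGE_STD = 128.0


def normalize_pixel(argb):
    """Split a packed ARGB int into (r, g, b) floats, as the Java loop does."""
    r = (argb >> 16) & 0xFF
    g = (argb >> 8) & 0xFF
    b = argb & 0xFF
    return (
        (r - IMAGE_MEAN) / IMAGE_STD,
        (g - IMAGE_MEAN) / IMAGE_STD,
        (b - IMAGE_MEAN) / IMAGE_STD,
    )


def to_float_buffer(pixels):
    """Flatten INPUT_SIZE * INPUT_SIZE pixels into the [1, 299, 299, 3] feed buffer."""
    if len(pixels) != INPUT_SIZE * INPUT_SIZE:
        raise ValueError(
            "expected {} pixels, got {}".format(INPUT_SIZE * INPUT_SIZE, len(pixels)))
    values = []
    for p in pixels:
        values.extend(normalize_pixel(p))
    return values
```

With a mean and standard deviation of 128, channel values 0 to 255 map to the range -1.0 to 0.9921875, and a 299 x 299 image yields 268203 floats, which is why an image of any other size does not fit the pre-allocated buffers.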
To run recognition, just call:
final List<Classifier.Recognition> results = classifier.recognizeImage(rightBitmap);
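The returned results is a List holding the name of each predicted class together with its confidence score.
The results are noticeably more accurate than with the earlier OpenCV approach.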
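How the classifier picks the entries of that list can be sketched outside Android as well. The snippet below is an illustration under assumptions, not the library's code: it reproduces the THRESHOLD and MAX_RESULTS filtering from TensorFlowImageClassifier, and the scores in the usage note are made up.

```python
import heapq

MAX_RESULTS = 2
THRESHOLD = 0.1


def top_results(scores, labels, max_results=MAX_RESULTS, threshold=THRESHOLD):
    """Return up to max_results (label, score) pairs above threshold, best first."""
    candidates = [
        (score, labels[i] if i < len(labels) else "unknown")
        for i, score in enumerate(scores)
        if score > threshold
    ]
    best = heapq.nlargest(max_results, candidates, key=lambda c: c[0])
    return [(label, score) for score, label in best]
```

Index i of the score vector corresponds to line i of output_labels.txt, so with the labels dogs and peoples and hypothetical scores of 0.92 and 0.08, only the dogs entry survives the 0.1 threshold.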
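The Android app looks much the same as the earlier Android project; only the recognition code has been replaced.
Finally, the Android source code can be downloaded here (the project is large, so the TensorFlow model is not included):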
https://download.csdn.net/download/qq_27063119/10346591