1. 程式人生 > >基於Inception v3進行多標籤訓練 修正了錯誤並進一步完善了程式碼

基於Inception v3進行多標籤訓練 修正了錯誤並進一步完善了程式碼

多標籤訓練只適合未修改inception v3網路的情形,不同於遷移學習。本文參考了基於Inception v3進行多標籤訓練  修正了錯誤並進一步完善了程式碼

資料集的準備,假設有3個類,每個類別差不多有50張圖,注意圖片的規模不能太少(一般一個類不小於25張圖),不然在驗證的時候會報除0錯誤。

先看一下目錄結構


images:需建立一個目錄(圖中為multi_image),存放所有的圖片, 不同於tf v1.1.0及之後官方的retrain,要求將多個類放在這個目錄

retrained_labels.txt:包含資料集提供的類別,具體看你的類別有啥

image_labels_dir:包含所有圖片名字+txt,比如 資料集有 圖boat.jpg,在此目錄下就會有boat.jpg.txt, 內容儲存圖片的標籤,一個圖片可能有多個標籤,為簡單測試,我使用getclass.sh來每張圖生成一個標籤

getclass.sh:將同一類別的不同圖片放在一個目錄下,將此檔案放在此目錄下執行,為每張圖片生成txt檔案,內容就是目錄的名字。之後將圖拷貝到 images/multi_image/,將**.jpg.txt檔案拷貝到 image_labels_dir

dir=$(cd "$(dirname "$0")";pwd)
basedir=$(basename ${dir})
echo "$basedir"

for name in $dir/*
do
  echo "$name"
  filename=$(basename ${name})
  echo "$filename"  
  echo "$basedir" >> ${filename}.txt
done
echo $0
rm -rf $(basename $0).txt
具體的retrain程式碼太多,就不貼出來了,見本人的 github

執行eval_retrain.sh指令碼,如果裡面的引數自己需要改動,根據情況設定即可

python retrain.py \
--bottleneck_dir=bottlenecks \
--how_many_training_steps 1000 \
--model_dir=model_dir \
--output_graph=retrained_graph.pb \
--output_labels=retrained_labels.txt \
--summaries_dir=retrain_logs \
--image_dir=images
執行完之後,會在此工程目錄下生成 retrained_graph.pb檔案和log日誌(可以在tensorbord下視覺化),使用label_image就可以測試了。注意,如果使用的是tf v1.1.0及之後的官方label_image.py需修改引數

input_layer = "Mul"
output_layer = "final_result"

專案中的label_image_v0.py是舊版,程式碼如下

import tensorflow as tf
import sys

# change this as you see fit
image_path = sys.argv[1]

# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()

# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line
                   in tf.gfile.GFile("retrained_labels.txt")]

# Unpersists graph from file
with tf.gfile.FastGFile("retrained_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

    predictions = sess.run(softmax_tensor, \
             {'DecodeJpeg/contents:0': image_data})

    # Sort to show labels of first prediction in order of confidence
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]

    for node_id in top_k:
        human_string = label_lines[node_id]
        score = predictions[0][node_id]
        print('%s (score = %.5f)' % (human_string, score))


    filename = "results.txt"
    with open(filename, 'a+') as f:
        f.write('\n**%s**\n' % (image_path))
        for node_id in top_k:
            human_string = label_lines[node_id]
            score = predictions[0][node_id]
            f.write('%s (score = %.5f)\n' % (human_string, score))
新版改動較大,程式碼如下:

# coding=utf-8
# 用於更改資料集重訓練模型
# ======================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy as np
import tensorflow as tf


def load_graph(model_file):
    graph = tf.Graph()
    graph_def = tf.GraphDef()

    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def)

    return graph


def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
    input_name = "file_reader"
    output_name = "normalized"
    file_reader = tf.read_file(file_name, input_name)
    if file_name.endswith(".png"):
        image_reader = tf.image.decode_png(file_reader, channels=3,
                                           name='png_reader')
    elif file_name.endswith(".gif"):
        image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                      name='gif_reader'))
    elif file_name.endswith(".bmp"):
        image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
    else:
        image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                            name='jpeg_reader')
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0);
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    sess = tf.Session()
    result = sess.run(normalized)

    return result


def load_labels(label_file):
    label = []
    proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
    for l in proto_as_ascii_lines:
        label.append(l.rstrip())
    return label


if __name__ == "__main__":
    # file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg"
    # model_file = \
    #    "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb"
    # label_file = "tensorflow/examples/label_image/data/imagenet_slim_labels.txt"
    file_name = 'test.jpg'
    model_file = 'retrained_graph.pb'
    label_file = 'retrained_labels.txt'
    input_height = 299
    input_width = 299
    input_mean = 0
    input_std = 255
    # -------必須修改下面的2個名稱
    input_layer = "Mul"
    output_layer = "final_result"

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", help="image to be processed")
    parser.add_argument("--graph", help="graph/model to be executed")
    parser.add_argument("--labels", help="name of file containing labels")
    parser.add_argument("--input_height", type=int, help="input height")
    parser.add_argument("--input_width", type=int, help="input width")
    parser.add_argument("--input_mean", type=int, help="input mean")
    parser.add_argument("--input_std", type=int, help="input std")
 parser.add_argument("--input_layer", help="name of input layer")
    parser.add_argument("--output_layer", help="name of output layer")
    args = parser.parse_args()

    if args.graph:
        model_file = args.graph
    if args.image:
        file_name = args.image
    if args.labels:
        label_file = args.labels
    if args.input_height:
        input_height = args.input_height
    if args.input_width:
        input_width = args.input_width
    if args.input_mean:
        input_mean = args.input_mean
    if args.input_std:
        input_std = args.input_std
    if args.input_layer:
        input_layer = args.input_layer
    if args.output_layer:
        output_layer = args.output_layer

    graph = load_graph(model_file)
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)

    input_name = "import/" + input_layer
    output_name = "import/" + output_layer
    input_operation = graph.get_operation_by_name(input_name)
    output_operation = graph.get_operation_by_name(output_name)

    with tf.Session(graph=graph) as sess:
        results = sess.run(output_operation.outputs[0],
                           {input_operation.outputs[0]: t})
    results = np.squeeze(results)

    top_k = results.argsort()[-5:][::-1]
    labels = load_labels(label_file)
    for i in top_k:
        print(labels[i], results[i])
給一張帆船的測試圖,分別執行二者程式,結果如下
[email protected]:~/PycharmProjects/multi-label$ python label_image_v1.py 
2018-01-25 12:49:10.944557: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2018-01-25 12:49:11.320347: W tensorflow/core/framework/op_def_util.cc:343] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
junk 0.971204
boat 0.0287959
[email protected]:~/PycharmProjects/multi-label$ python label_image_v0.py test.jpg
2018-01-25 12:49:32.411879: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2018-01-25 12:49:32.773317: W tensorflow/core/framework/op_def_util.cc:343] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
carrier (score = 0.98324)
boat (score = 0.01676)
可以發現,新版的label_image的測試結果更好,建議使用。

具體的細節見github

參考文獻

基於inception v3進行多標籤訓練

CRITICAL: tensorflow:Category has no images - validation

保證至少1張驗證集圖片的改良retrain