
End-to-End License Plate / Captcha Recognition (TensorFlow) (2)


II. CNN Approach

4. Model Training

First, the full training script, train.py:

"""
Created on Tue Sep  5 15:37:26 2017

@author: llc
"""
#%%
import os
import numpy as np
import tensorflow as tf
from input_data import OCRIter
import model
#from genplate import *
import time
import datetime

img_w = 272
img_h = 72
num_label=7
batch_size = 8
count =30000
learning_rate = 0.0001

# default layout: [N, H, W, C]
image_holder = tf.placeholder(tf.float32,[batch_size,img_h,img_w,3])
label_holder = tf.placeholder(tf.int32,[batch_size,num_label])
keep_prob = tf.placeholder(tf.float32)

logs_train_dir = '/home/llc/TF_test/Chinese_plate_recognition/Plate_recognition/train_logs_50000/'

def get_batch():
    data_batch = OCRIter(batch_size,img_h,img_w)
    image_batch,label_batch = data_batch.iter()

    image_batch1 = np.array(image_batch)
    label_batch1 = np.array(label_batch)

    return image_batch1,label_batch1


train_logits1,train_logits2,train_logits3,train_logits4,train_logits5,train_logits6,train_logits7= model.inference(image_holder,keep_prob)

train_loss1,train_loss2,train_loss3,train_loss4,train_loss5,train_loss6,train_loss7 = model.losses(train_logits1,train_logits2,train_logits3,train_logits4,train_logits5,train_logits6,train_logits7,label_holder) 
train_op1,train_op2,train_op3,train_op4,train_op5,train_op6,train_op7 = model.trainning(train_loss1,train_loss2,train_loss3,train_loss4,train_loss5,train_loss6,train_loss7,learning_rate)

train_acc = model.evaluation(train_logits1,train_logits2,train_logits3,train_logits4,train_logits5,train_logits6,train_logits7,label_holder)

input_image=tf.summary.image('input',image_holder)
#tf.summary.histogram('label',label_holder) # histogram of the labels, used when debugging the training code; see http://geek.csdn.net/news/detail/197155

summary_op = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))
#sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))  # log device placement
sess = tf.Session() 
train_writer = tf.summary.FileWriter(logs_train_dir,sess.graph)
saver = tf.train.Saver()

sess.run(tf.global_variables_initializer())

start_time1 = time.time()    

for step in range(count): 

    x_batch,y_batch = get_batch()

    start_time2 = time.time()   
    time_str = datetime.datetime.now().isoformat()

    feed_dict = {image_holder:x_batch,label_holder:y_batch,keep_prob:0.5}
    _,_,_,_,_,_,_,tra_loss1,tra_loss2,tra_loss3,tra_loss4,tra_loss5,tra_loss6,tra_loss7,acc,summary_str= sess.run([train_op1,train_op2,train_op3,train_op4,train_op5,train_op6,train_op7,train_loss1,train_loss2,train_loss3,train_loss4,train_loss5,train_loss6,train_loss7,train_acc,summary_op],feed_dict)
    train_writer.add_summary(summary_str,step)
    duration = time.time()-start_time2
    tra_all_loss =tra_loss1+tra_loss2+tra_loss3+tra_loss4+tra_loss5+tra_loss6+tra_loss7

    #print(y_batch)  # debug only: check that the generated samples match their labels

    if step % 10 == 0:
        sec_per_batch = float(duration)
        print('%s : Step %d,train_loss = %.2f,acc= %.2f,sec/batch=%.3f' %(time_str,step,tra_all_loss,acc,sec_per_batch))

    if step % 10000 == 0 or (step+1) == count:
        checkpoint_path = os.path.join(logs_train_dir,'model.ckpt')
        saver.save(sess,checkpoint_path,global_step=step)
sess.close()       
print(time.time()-start_time1)

There is not much to explain here: training follows the standard pattern. Data is read in through placeholders, and each batch is generated on the fly and immediately fed in for training. Since model.py (inference, losses, trainning, evaluation) is not listed in this post, a sketch of those helpers is given below.
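For readers without model.py at hand, here is a minimal, hypothetical sketch of the seven-head helpers it is assumed to expose. The real module takes seven separate logits arguments rather than a list, its layer structure is not shown, and Adam is an assumed optimizer choice; each head is taken to be a 65-way classifier (31 province characters plus 34 letters and digits).

# Hypothetical sketch only: the actual model.py takes seven separate
# logits arguments; a list is used here for brevity.
import tensorflow as tf

def losses(logits_list, labels):
    # labels: int32 [batch, 7]; one cross-entropy term per character position
    loss_list = []
    for i, logits in enumerate(logits_list):
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels[:, i])
        loss_list.append(tf.reduce_mean(cross_entropy))
        tf.summary.scalar('loss_%d' % i, loss_list[-1])
    return loss_list

def trainning(loss_list, learning_rate):
    # one optimizer op per head, matching the seven train_op outputs above
    return [tf.train.AdamOptimizer(learning_rate).minimize(loss)
            for loss in loss_list]

def evaluation(logits_list, labels):
    # fraction of character positions predicted correctly over the batch
    corrects = [tf.nn.in_top_k(logits, labels[:, i], 1)
                for i, logits in enumerate(logits_list)]
    accuracy = tf.reduce_mean(tf.cast(tf.concat(corrects, 0), tf.float32))
    tf.summary.scalar('accuracy', accuracy)
    return accuracy

Note that this sketch scores accuracy per character position, not per whole plate.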

5. Model Validation

(1) Testing a single image:

import tensorflow as tf
import numpy as np
import os
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import model
index = {"京": 0, "滬": 1, "津": 2, "渝": 3, "冀": 4, "晉": 5, "蒙": 6, "遼": 7, "吉": 8, "黑": 9, "蘇": 10, "浙": 11, "皖": 12,
         "閩": 13, "贛": 14, "魯": 15, "豫": 16, "鄂": 17, "湘": 18, "粵": 19, "桂": 20, "瓊": 21, "川": 22, "貴": 23, "雲": 24,
         "藏": 25, "陝": 26, "甘": 27, "青": 28, "寧": 29, "新": 30, "0": 31, "1": 32, "2": 33, "3": 34, "4": 35, "5": 36,
         "6": 37, "7": 38, "8": 39, "9": 40, "A": 41, "B": 42, "C": 43, "D": 44, "E": 45, "F": 46, "G": 47, "H": 48,
         "J": 49, "K": 50, "L": 51, "M": 52, "N": 53, "P": 54, "Q": 55, "R": 56, "S": 57, "T": 58, "U": 59, "V": 60,
         "W": 61, "X": 62, "Y": 63, "Z": 64};

chars = ["京", "滬", "津", "渝", "冀", "晉", "蒙", "遼", "吉", "黑", "蘇", "浙", "皖", "閩", "贛", "魯", "豫", "鄂", "湘", "粵", "桂",
             "瓊", "川", "貴", "雲", "藏", "陝", "甘", "青", "寧", "新", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A",
             "B", "C", "D", "E", "F", "G", "H", "J", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "U", "V", "W", "X",
             "Y", "Z"
             ]
'''
Test one image against the saved models and parameters
'''

def get_one_image(test):
    '''
    Randomly pick one image from the test set
    Return: ndarray
    '''
    n = len(test)
    ind =np.random.randint(0,n)
    img_dir = test[ind]

    image_show = Image.open(img_dir)
    plt.imshow(image_show)
    #image = image.resize([120,30])
    image = cv2.imread(img_dir)
    img = np.multiply(image,1/255.0)
    #image = np.array(img)
    #image = img.transpose(1,0,2)
    image = np.array([img])
    print(image.shape)

    return image

batch_size = 1
x = tf.placeholder(tf.float32,[batch_size,72,272,3])
keep_prob =tf.placeholder(tf.float32)

test_dir = '/home/llc/TF_test/Chinese_plate_recognition/Plate_recognition/plate/'
test_image = []
for file in os.listdir(test_dir):
    test_image.append(test_dir + file)

image_array = get_one_image(test_image)

#logit = model.inference(x,keep_prob)
logit1,logit2,logit3,logit4,logit5,logit6,logit7 = model.inference(x,keep_prob)

#logit1 = tf.nn.softmax(logit1)
#logit2 = tf.nn.softmax(logit2)
#logit3 = tf.nn.softmax(logit3)
#logit4 = tf.nn.softmax(logit4)
#logit5 = tf.nn.softmax(logit5)
#logit6 = tf.nn.softmax(logit6)
#logit7 = tf.nn.softmax(logit7)

logs_train_dir = '/home/llc/TF_test/Chinese_plate_recognition/Plate_recognition/train_logs_50000/'

saver = tf.train.Saver()

with tf.Session() as sess:
    print ("Reading checkpoint...")
    ckpt = tf.train.get_checkpoint_state(logs_train_dir)
    if ckpt and ckpt.model_checkpoint_path:
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Loading success, global_step is %s' % global_step)
    else:
        print('No checkpoint file found')

    pre1,pre2,pre3,pre4,pre5,pre6,pre7 = sess.run([logit1,logit2,logit3,logit4,logit5,logit6,logit7], feed_dict={x: image_array,keep_prob:1.0})
    prediction = np.reshape(np.array([pre1,pre2,pre3,pre4,pre5,pre6,pre7]),[-1,65])
    #prediction = np.array([[pre1],[pre2],[pre3],[pre4],[pre5],[pre6],[pre7]])
    #print(prediction)

    max_index = np.argmax(prediction,axis=1)
    print(max_index)
    line = ''
    for i in range(prediction.shape[0]):
        # Constrain each position's argmax to its valid character set:
        # position 0 is a province character (indices 0-30), position 1 a
        # letter (41-64), positions 2-6 letters or digits (31-64).
        if i == 0:
            result = np.argmax(prediction[i][0:31])
        if i == 1:
            result = np.argmax(prediction[i][41:65])+41
        if i > 1:
            result = np.argmax(prediction[i][31:65])+31

        line += chars[result]+" "
    print ('predicted: ' + line)  

The plate images are generated and saved with genplate.py, then read back with cv2.imread and fed in through tf.placeholder for testing (note that images must be saved and read the same way; see the note below).
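One concrete way this consistency requirement bites, as a hedged example (the file name here is illustrative): OpenCV reads images in BGR channel order while PIL and matplotlib assume RGB, so test images should be read with the same library, and in the same channel order, as the training images.

import cv2

# 'plate.jpg' is a placeholder name. cv2.imread returns BGR; if the training
# images were also read with cv2, feed the BGR array to the network as-is
# and convert only for display.
image_bgr = cv2.imread('plate.jpg')
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)  # for plt.imshow only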

(2) Testing multiple images. This part uses genplate.py to generate a large synthetic test set, reads the samples with tf.train.slice_input_producer, predicts every sample in the set, and reports the test-set accuracy together with the file numbers of the misrecognized images. A sketch of this queue-based input pipeline follows.
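The sketch below is a minimal version of such a reader, assuming the same 72x272x3 plate images; the batch size and variable names are illustrative, and the actual test script may differ.

import os
import tensorflow as tf

test_dir = '/home/llc/TF_test/Chinese_plate_recognition/Plate_recognition/plate/'
batch_size = 32

# shuffle=False keeps file order stable so misrecognized images can be
# reported by file name; num_epochs=1 reads the test set exactly once.
paths = [os.path.join(test_dir, f) for f in sorted(os.listdir(test_dir))]
path_tensor, = tf.train.slice_input_producer([paths], shuffle=False, num_epochs=1)

image = tf.image.decode_jpeg(tf.read_file(path_tensor), channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)  # scale to [0, 1]
image.set_shape([72, 272, 3])
image_batch, path_batch = tf.train.batch(
    [image, path_tensor], batch_size=batch_size,
    allow_smaller_final_batch=True)

with tf.Session() as sess:
    # num_epochs is backed by a local variable, hence the second initializer
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    imgs, names = sess.run([image_batch, path_batch])
    # ...run model.inference on imgs and compare predictions to labels here...
    coord.request_stop()
    coord.join(threads)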

Results on 32 test images: [image]

Recognition accuracy: 0.938 (only 32 images, so the number is not meaningful). Misrecognized images: 14.jpg: [image], misread as 遼Z RD2DS (5 confused with S); 24.jpg: [image], misread as 閩X 7GW13 (3 confused with a half-occluded S).

Results on 500 images: accuracy 0.822 (89 misrecognized). [image] Error types: [image] an A at the very edge recognized as T;

[image] Recognized as: 冀 D 1 0 0 Y Y

[image] Recognized as: 渝 J U W L 1 K

Most of the errors involve similar-looking or partially occluded characters, though a few seemingly clear plates are also misread. The training itself can still be improved: all results above come from 30,000 iterations, i.e. 30,000 × batch_size (8) = 240,000 training samples, and both the training and test samples are synthetically generated by the code. The model has not yet been tested on real captured plate images, where performance is likely to be worse; the best approach would be to train on real samples as well (tests on real samples will follow in a later update, since overfitting to the synthetic samples is a real possibility).

6. References

This is meant as a round-up of deep-learning-based license plate recognition. The LSTM-based method will follow in a later update; discussion is welcome.