端到端車牌/驗證碼識別(tensorflow版)——(2)
阿新 • • 發佈:2018-12-13
端到端車牌識別(2)
二、CNN方法
4. 模型訓練
先附上程式碼train.py:
"""
Created on Tue Sep  5 15:37:26 2017

@author: llc

Training script for the end-to-end CNN plate recognizer. Every step generates
a fresh batch with OCRIter, runs the seven per-character train ops together,
writes summaries, and periodically saves a checkpoint.
"""
#%%
import os
import numpy as np
import tensorflow as tf
from input_data import OCRIter
import model
#from genplate import *
import time
import datetime

img_w = 272
img_h = 72
num_label = 7          # characters per plate; model.inference returns one logits tensor each
batch_size = 8
count = 30000          # number of training steps
learning_rate = 0.0001

# Placeholders use the [N, H, W, C] layout.
image_holder = tf.placeholder(tf.float32, [batch_size, img_h, img_w, 3])
label_holder = tf.placeholder(tf.int32, [batch_size, num_label])
# Presumably the dropout keep probability (fed 0.5 while training) - confirm in model.py.
keep_prob = tf.placeholder(tf.float32)

logs_train_dir = '/home/llc/TF_test/Chinese_plate_recognition/Plate_recognition/train_logs_50000/'


def get_batch():
    """Draw one batch of generated samples.

    Builds an OCRIter for the configured batch/image size and takes a single
    batch from it via ``iter()``.

    Returns:
        Tuple ``(images, labels)`` of numpy arrays, fed to ``image_holder`` and
        ``label_holder`` respectively.
    """
    data_batch = OCRIter(batch_size, img_h, img_w)
    image_batch, label_batch = data_batch.iter()
    return np.array(image_batch), np.array(label_batch)


# ---- Graph construction ----------------------------------------------------
(train_logits1, train_logits2, train_logits3, train_logits4,
 train_logits5, train_logits6, train_logits7) = model.inference(image_holder, keep_prob)
(train_loss1, train_loss2, train_loss3, train_loss4,
 train_loss5, train_loss6, train_loss7) = model.losses(
    train_logits1, train_logits2, train_logits3, train_logits4,
    train_logits5, train_logits6, train_logits7, label_holder)
(train_op1, train_op2, train_op3, train_op4,
 train_op5, train_op6, train_op7) = model.trainning(
    train_loss1, train_loss2, train_loss3, train_loss4,
    train_loss5, train_loss6, train_loss7, learning_rate)
train_acc = model.evaluation(
    train_logits1, train_logits2, train_logits3, train_logits4,
    train_logits5, train_logits6, train_logits7, label_holder)

train_ops = [train_op1, train_op2, train_op3, train_op4,
             train_op5, train_op6, train_op7]
train_losses = [train_loss1, train_loss2, train_loss3, train_loss4,
                train_loss5, train_loss6, train_loss7]

input_image = tf.summary.image('input', image_holder)
# Debug aid: tf.summary.histogram('label', label_holder) helps verify labels
# while testing the training code (see http://geek.csdn.net/news/detail/197155).
# Merge every summary registered in the default graph.
summary_op = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))

# ---- Training ---------------------------------------------------------------
# For device-placement logging use:
#   tf.Session(config=tf.ConfigProto(log_device_placement=True))
sess = tf.Session()
train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
# A single Saver is reused for every checkpoint; constructing a Saver inside
# the loop would add a new set of save/restore ops to the graph each time.
saver = tf.train.Saver()
checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
sess.run(tf.global_variables_initializer())

start_time1 = time.time()
try:
    for step in range(count):
        x_batch, y_batch = get_batch()
        # The timing below covers only sess.run, not batch generation.
        start_time2 = time.time()
        time_str = datetime.datetime.now().isoformat()
        feed_dict = {image_holder: x_batch, label_holder: y_batch, keep_prob: 0.5}

        results = sess.run(train_ops + train_losses + [train_acc, summary_op], feed_dict)
        n_ops = len(train_ops)
        tra_losses = results[n_ops:n_ops + len(train_losses)]
        acc, summary_str = results[-2], results[-1]

        train_writer.add_summary(summary_str, step)
        duration = time.time() - start_time2
        tra_all_loss = sum(tra_losses)
        # Debug aid: print(y_batch) to check that samples and labels line up.

        if step % 10 == 0:
            sec_per_batch = float(duration)
            print('%s : Step %d,train_loss = %.2f,acc= %.2f,sec/batch=%.3f'
                  % (time_str, step, tra_all_loss, acc, sec_per_batch))

        # Checkpoint at step 0, every 10000 steps, and after the final step.
        if step % 10000 == 0 or (step + 1) == count:
            saver.save(sess, checkpoint_path, global_step=step)
finally:
    # Flush pending summaries and release the session even if training fails.
    train_writer.close()
    sess.close()
print(time.time() - start_time1)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
這部分沒多大可講的,基本是採用常規訓練方法,資料的讀入採用Placeholder,每次產生一個batch就匯入訓練一個batch.
5. 模型驗證
(1) 測試單張圖:
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import model

# Output classes; decoding below maps each logit index i to chars[i].
# Layout: 31 province abbreviations, then digits 0-9, then 24 Latin letters
# (no I or O).
chars = ["京", "滬", "津", "渝", "冀", "晉", "蒙", "遼", "吉", "黑",
         "蘇", "浙", "皖", "閩", "贛", "魯", "豫", "鄂", "湘", "粵",
         "桂", "瓊", "川", "貴", "雲", "藏", "陝", "甘", "青", "寧",
         "新",
         "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
         "A", "B", "C", "D", "E", "F", "G", "H", "J", "K",
         "L", "M", "N", "P", "Q", "R", "S", "T", "U", "V",
         "W", "X", "Y", "Z"]
# Reverse lookup: character -> class index (derived from `chars` so the two
# tables can never drift apart).
index = {ch: i for i, ch in enumerate(chars)}

# Class-index boundaries within `chars`, used to restrict each plate position.
N_PROVINCES = chars.index("0")     # 31: provinces occupy [0, 31)
LETTERS_START = chars.index("A")   # 41: letters occupy [41, 65)
N_CLASSES = len(chars)             # 65

# Test one image against the saved models and parameters.


def get_one_image(test):
    """Randomly pick one image path from *test*, show it and load it.

    The chosen image is drawn with ``plt.imshow``; no ``plt.show()`` is called,
    so it is only visible in an interactive session.

    Args:
        test: sequence of image file paths.

    Returns:
        ndarray of shape (1, H, W, C) holding the image read by ``cv2.imread``
        (channel order as returned by OpenCV) scaled to [0, 1]. The reading
        method must match how the training images were produced.

    Raises:
        ValueError: if *test* is empty or the chosen file cannot be decoded.
    """
    n = len(test)
    if n == 0:
        raise ValueError('get_one_image: no test images given')
    ind = np.random.randint(0, n)
    img_dir = test[ind]
    image_show = Image.open(img_dir)
    plt.imshow(image_show)
    image = cv2.imread(img_dir)
    if image is None:  # cv2.imread signals failure by returning None
        raise ValueError('get_one_image: cannot read image %s' % img_dir)
    img = np.multiply(image, 1 / 255.0)
    image = np.array([img])  # add the batch dimension
    print(image.shape)
    return image


batch_size = 1
# [N, H, W, C]; 72x272x3 matches the size used for training.
x = tf.placeholder(tf.float32, [batch_size, 72, 272, 3])
keep_prob = tf.placeholder(tf.float32)

test_dir = '/home/llc/TF_test/Chinese_plate_recognition/Plate_recognition/plate/'
test_image = [os.path.join(test_dir, fname) for fname in os.listdir(test_dir)]
image_array = get_one_image(test_image)

# Raw logits are enough: softmax does not change the argmax taken below.
logit1, logit2, logit3, logit4, logit5, logit6, logit7 = model.inference(x, keep_prob)

logs_train_dir = '/home/llc/TF_test/Chinese_plate_recognition/Plate_recognition/train_logs_50000/'
saver = tf.train.Saver()

with tf.Session() as sess:
    print("Reading checkpoint...")
    ckpt = tf.train.get_checkpoint_state(logs_train_dir)
    if ckpt and ckpt.model_checkpoint_path:
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Loading success, global_step is %s' % global_step)
    else:
        # Restoring is the only way variables get values in this script;
        # running inference without it would fail, so stop here.
        print('No checkpoint file found')
        raise SystemExit(1)

    pre1, pre2, pre3, pre4, pre5, pre6, pre7 = sess.run(
        [logit1, logit2, logit3, logit4, logit5, logit6, logit7],
        feed_dict={x: image_array, keep_prob: 1.0})
    # One row of N_CLASSES scores per plate position (assumes each logit
    # tensor is (batch_size, N_CLASSES) with batch_size == 1 - confirm in model.py).
    prediction = np.reshape(np.array([pre1, pre2, pre3, pre4, pre5, pre6, pre7]),
                            [-1, N_CLASSES])
    max_index = np.argmax(prediction, axis=1)  # unrestricted argmax, for debugging
    print(max_index)

    picked = []
    for i, scores in enumerate(prediction):
        if i == 0:
            # Position 0: province abbreviations only.
            result = np.argmax(scores[0:N_PROVINCES])
        elif i == 1:
            # Position 1: letters only.
            result = np.argmax(scores[LETTERS_START:N_CLASSES]) + LETTERS_START
        else:
            # Remaining positions: digits or letters.
            result = np.argmax(scores[N_PROVINCES:N_CLASSES]) + N_PROVINCES
        picked.append(chars[result] + " ")
    line = ''.join(picked)
    print('predicted: ' + line)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
利用genplate.py生成車牌圖片並儲存,然後利用cv2.imread讀取圖片,tf.placeholder讀入資料進行測試(注意圖片儲存與讀取方式要一致)。
(2) 測試多張圖:此部分利用genplate.py產生大量測試樣本,採用tf.train.slice_input_producer方式讀取樣本集,預測所有測試樣本,並給出測試集的識別準確率以及識別錯誤的影象編號。測試32張圖結果:
識別準確率為:0.938(僅測試32張,統計意義不大)。識別錯誤圖片編號:14.jpg:錯誤識別為「遼Z RD2DS」(5與S混淆);24.jpg:錯誤識別為「閩X 7GW13」(3與半遮擋的S混淆)。
測試500張結果:準確率0.822(錯誤89張) 錯誤型別: 最邊上A識別為T;
識別為:冀 D 1 0 0 Y Y
識別為:渝 J U W L 1 K
基本都是相似的字元或遮擋的識別錯誤,當然也有少量的看似清晰的識別錯誤。同時,模型的訓練可以優化,以上所有的結果都是迭代30000次,訓練30000×batch_size(8)=24萬張樣本的結果,且全部都是程式碼生成的訓練及測試樣本。因此,其應用到實際採集的真實車牌圖片上還未測試,估計效果較差,最好的訓練方法是結合實際採集的樣本一起訓練(實際樣本的測試後續更新,畢竟也有在模擬樣本上的過擬合可能)。
6. 參考
以上算是車牌識別深度學習的一個彙總吧。基於LSTM方法的後續更新,歡迎大家一起討論。