2.CNN圖片多標籤分類(基於TensorFlow實現驗證碼識別OCR)
上一篇實現了圖片CNN單標籤分類(貓狗圖片分類任務)
地址: ofollow,noindex">juejin.im/post/5c0739…
預告:下一篇用LSTM+CTC實現不定長文字的OCR,本質上是一種不固定標籤個數的多標籤分類問題
本文所用到的10w驗證碼資料集百度網盤下載地址(也可使用下文程式碼自行生成):
利用本文程式碼訓練並生成的模型(對應專案中的model資料夾):
專案簡介:
(需要預先安裝pip install captcha==0.1.1,pip install opencv-python,pip install flask, pip install tensorflow/pip install tensorflow-gpu) 本文采用CNN實現4位定長驗證碼圖片OCR(生成的驗證碼固定由隨機的4位大寫字母組成),本質上是一張圖片多個標籤的分類問題(資料如下圖所示)

整體訓練邏輯:
1,將影象傳入到CNN中提取特徵
2,將特徵圖拉伸輸入到FC layer中得出分類預測向量
3,通過sigmoid交叉熵函式對預測向量和標籤向量進行訓練,得出最終模型(注意:多標籤分類任務採用sigmoid,單標籤分類採用softmax)
整體預測邏輯:
1,將影象傳入到CNN(VGG16)中提取特徵
2,將特徵圖拉伸輸入到FC layer中得出分類預測向量
3,將預測向量做sigmoid操作,由於驗證碼固定是4位,所以將向量切分成4條,從每條中找到最大值,並對映到對應的字母上
製作成web服務:
利用flask框架將整個專案啟動成web服務,使得專案支援http方式呼叫 啟動服務後呼叫以下地址測試
http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/0_HZDZ.png
http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/1_CKAN.png
後續優化邏輯:
提取特徵部分的CNN可以用RNN取代
本方案只能OCR固定長度文字,後續採用LSTM+CTC的方式來OCR非定長文字
執行命令:
自行生成驗證碼訓練寄(本文生成了10w張,修改self.im_total_num變數): pythonCnnOcr.py create_dataset
對資料集進行訓練:pythonCnnOcr.py train
對新的圖片進行測試:pythonCnnOcr.py test
啟動成http服務:pythonCnnOcr.py start
專案目錄結構:

訓練過程:







整體程式碼如下:
# coding:utf-8 from captcha.image import ImageCaptcha import numpy as np import cv2 import tensorflow as tf import random, os, sys from flask import request from flask import Flask import json app = Flask(__name__) class CnnOcr: def __init__(self): self.epoch_max = 6# 最大迭代epoch次數 self.batch_size = 64# 訓練時每個批次參與訓練的影象數目,視訊記憶體不足的可以調小 self.lr = 1e-3# 初始學習率 self.save_epoch = 1# 每相隔多少個epoch儲存一次模型 self.im_width = 128 self.im_height = 64 self.im_total_num = 100000# 總共生成的驗證碼圖片數量 self.train_max_num = self.im_total_num# 訓練時讀取的最大圖片數目 self.val_num = 50 * self.batch_size# 不能大於self.train_max_num做驗證集用 self.words_num = 4# 每張驗證碼圖片上的數字個數 self.words = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' self.label_num = self.words_num * len(self.words) self.keep_drop = tf.placeholder(tf.float32) self.x = None self.y = None def captchaOcr(self, img_path): """ 驗證碼識別 :param img_path: :return: """ im = cv2.imread(img_path) im = cv2.resize(im, (self.im_width, self.im_height)) im = [im] im = np.array(im, dtype=np.float32) im -= 147 output = self.sess.run(self.max_idx_p, feed_dict={self.x: im, self.keep_drop: 1.}) ret = '' for i in output.tolist()[0]: ret = ret + self.words[int(i)] return ret def test(self, img_path): """ 測試介面 :param img_path: :return: """ self.x = tf.placeholder(tf.float32, [None, self.im_height, self.im_width, 3])# 輸入資料 self.pred = self.cnnNet() self.output = tf.nn.sigmoid(self.pred) self.predict = tf.reshape(self.pred, [-1, self.words_num, len(self.words)]) self.max_idx_p = tf.argmax(self.predict, 2) saver = tf.train.Saver() # tfconfig = tf.ConfigProto(allow_soft_placement=True) # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.3# 佔用視訊記憶體的比例 # self.ses = tf.Session(config=tfconfig) self.sess = tf.Session() self.sess.run(tf.global_variables_initializer())# 全域性tf變數初始化 # 載入w,b引數 saver.restore(self.sess, './model/CnnOcr-6') im = cv2.imread(img_path) im = cv2.resize(im, (self.im_width, self.im_height)) im = [im] im = np.array(im, dtype=np.float32) im -= 147 output = self.sess.run(self.max_idx_p, feed_dict={self.x: im, self.keep_drop: 1.}) ret = '' for i in output.tolist()[0]: ret = ret + self.words[int(i)] print(ret) def train(self): x_train_list, y_train_list, x_val_list, y_val_list = self.getTrainDataset() print('開始轉換tensor佇列') x_train_list_tensor = tf.convert_to_tensor(x_train_list, dtype=tf.string) y_train_list_tensor = tf.convert_to_tensor(y_train_list, dtype=tf.float32) x_val_list_tensor = tf.convert_to_tensor(x_val_list, dtype=tf.string) y_val_list_tensor = tf.convert_to_tensor(y_val_list, dtype=tf.float32) x_train_queue = tf.train.slice_input_producer(tensor_list=[x_train_list_tensor], shuffle=False) y_train_queue = tf.train.slice_input_producer(tensor_list=[y_train_list_tensor], shuffle=False) x_val_queue = tf.train.slice_input_producer(tensor_list=[x_val_list_tensor], shuffle=False) y_val_queue = tf.train.slice_input_producer(tensor_list=[y_val_list_tensor], shuffle=False) train_im, train_label = self.dataset_opt(x_train_queue, y_train_queue) train_batch = tf.train.batch(tensors=[train_im, train_label], batch_size=self.batch_size, num_threads=2) val_im, val_label = self.dataset_opt(x_val_queue, y_val_queue) val_batch = tf.train.batch(tensors=[val_im, val_label], batch_size=self.batch_size, num_threads=2) print('開啟訓練') self.learning_rate = tf.placeholder(dtype=tf.float32)# 動態學習率 self.x = tf.placeholder(tf.float32, [None, self.im_height, self.im_width, 3])# 訓練資料 self.y = tf.placeholder(tf.float32, [None, self.label_num])# 標籤 self.pred = self.cnnNet() self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.pred, labels=self.y)) self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) self.predict = tf.reshape(self.pred, [-1, self.words_num, len(self.words)]) self.max_idx_p = tf.argmax(self.predict, 2) self.y_predict = tf.reshape(self.y, [-1, self.words_num, len(self.words)]) self.max_idx_l = tf.argmax(self.y_predict, 2) self.correct_pred = tf.equal(self.max_idx_p, self.max_idx_l) self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32)) with tf.Session() as self.sess: # 全域性tf變數初始化 self.sess.run(tf.global_variables_initializer()) coordinator = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=self.sess, coord=coordinator) # 模型儲存 saver = tf.train.Saver() batch_max = len(x_train_list) // self.batch_size total_step = 1 for epoch_num in range(self.epoch_max): lr = self.lr * (1 - (epoch_num/self.epoch_max) ** 2)# 動態學習率 for batch_num in range(batch_max): x_train_tmp, y_train_tmp = self.sess.run(train_batch) # print(x_train_tmp.shape, y_train_tmp.shape) # sys.exit() self.sess.run(self.optimizer, feed_dict={self.x: x_train_tmp, self.y: y_train_tmp, self.learning_rate: lr, self.keep_drop: .5}) # 輸出評價標準 if total_step % 50 == 0 or total_step == 1: print() print('epoch:%d/%d batch:%d/%d step:%d lr:%.10f' % ((epoch_num + 1), self.epoch_max, (batch_num + 1), batch_max, total_step, lr)) # 輸出訓練集評價 train_loss, train_acc = self.sess.run([self.loss, self.accuracy], feed_dict={self.x: x_train_tmp, self.y: y_train_tmp, self.keep_drop: 1.}) print('train_loss:%.10ftrain_acc:%.10f' % (np.mean(train_loss), train_acc)) # 輸出驗證集評價 val_loss_list, val_acc_list = [], [] for i in range(int(self.val_num/self.batch_size)): x_val_tmp, y_val_tmp = self.sess.run(val_batch) val_loss, val_acc = self.sess.run([self.loss, self.accuracy], feed_dict={self.x: x_val_tmp, self.y: y_val_tmp, self.keep_drop: 1.}) val_loss_list.append(np.mean(val_loss)) val_acc_list.append(np.mean(val_acc)) print('val_loss:%.10fval_acc:%.10f' % (np.mean(val_loss), np.mean(val_acc))) total_step += 1 # 儲存模型 if (epoch_num + 1) % self.save_epoch == 0: print('正在儲存模型:') saver.save(self.sess, './model/CnnOcr', global_step=(epoch_num + 1)) coordinator.request_stop() coordinator.join(threads) def cnnNet(self): """ cnn網路 :return: """ weight = { # 輸入 128*64*3 # 第一層 'wc1_1': tf.get_variable('wc1_1', [5, 5, 3, 32]),# 卷積 輸出:128*64*32 'wc1_2': tf.get_variable('wc1_2', [5, 5, 32, 32]),# 卷積 輸出:128*64*32 # 池化 輸出:64*32*32 # 第二層 'wc2_1': tf.get_variable('wc2_1', [5, 5, 32, 64]),# 卷積 輸出:64*32*64 'wc2_2': tf.get_variable('wc2_2', [5, 5, 64, 64]),# 卷積 輸出:64*32*64 # 池化 輸出:32*16*64 # 第三層 'wc3_1': tf.get_variable('wc3_1', [3, 3, 64, 64]),# 卷積 輸出:32*16*256 'wc3_2': tf.get_variable('wc3_2', [3, 3, 64, 64]),# 卷積 輸出:32*16*256 # 池化 輸出:16*8*256 # 第四層 'wc4_1': tf.get_variable('wc4_1', [3, 3, 64, 64]),# 卷積 輸出:16*8*64 'wc4_2': tf.get_variable('wc4_2', [3, 3, 64, 64]),# 卷積 輸出:16*8*64 # 池化 輸出:8*4*64 # 全連結第一層 'wfc_1': tf.get_variable('wfc_1', [8*4*64, 2048]), # 全連結第二層 'wfc_2': tf.get_variable('wfc_2', [2048, 2048]), # 全連結第三層 'wfc_3': tf.get_variable('wfc_3', [2048, self.label_num]), } biase = { # 第一層 'bc1_1': tf.get_variable('bc1_1', [32]), 'bc1_2': tf.get_variable('bc1_2', [32]), # 第二層 'bc2_1': tf.get_variable('bc2_1', [64]), 'bc2_2': tf.get_variable('bc2_2', [64]), # 第三層 'bc3_1': tf.get_variable('bc3_1', [64]), 'bc3_2': tf.get_variable('bc3_2', [64]), # 第四層 'bc4_1': tf.get_variable('bc4_1', [64]), 'bc4_2': tf.get_variable('bc4_2', [64]), # 全連結第一層 'bfc_1': tf.get_variable('bfc_1', [2048]), # 全連結第二層 'bfc_2': tf.get_variable('bfc_2', [2048]), # 全連結第三層 'bfc_3': tf.get_variable('bfc_3', [self.label_num]), } # 第一層 net = tf.nn.conv2d(self.x, weight['wc1_1'], [1, 1, 1, 1], 'SAME')# 卷積 net = tf.nn.bias_add(net, biase['bc1_1']) net = tf.nn.relu(net)# 加b 然後 啟用 print('conv1', net) net = tf.nn.max_pool(net, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')# 池化 print('pool1', net) # 第二層 net = tf.nn.conv2d(net, weight['wc2_1'], [1, 1, 1, 1], padding='SAME')# 卷積 net = tf.nn.bias_add(net, biase['bc2_1']) net = tf.nn.relu(net)# 加b 然後 啟用 print('conv2', net) net = tf.nn.max_pool(net, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')# 池化 print('pool2', net) # 第三層 net = tf.nn.conv2d(net, weight['wc3_1'], [1, 1, 1, 1], padding='SAME')# 卷積 net = tf.nn.bias_add(net, biase['bc3_1']) net = tf.nn.relu(net)# 加b 然後 啟用 print('conv3', net) net = tf.nn.max_pool(net, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')# 池化 print('pool3', net) # 第四層 net = tf.nn.conv2d(net, weight['wc4_1'], [1, 1, 1, 1], padding='SAME')# 卷積 net = tf.nn.bias_add(net, biase['bc4_1']) net = tf.nn.relu(net)# 加b 然後 啟用 print('conv4', net) net = tf.nn.max_pool(net, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')# 池化 print('pool4', net) # 拉伸flatten,把多個圖片同時分別拉伸成一條向量 net = tf.reshape(net, shape=[-1, weight['wfc_1'].get_shape()[0]]) print('拉伸flatten', net) # 全連結層 # fc第一層 net = tf.matmul(net, weight['wfc_1']) + biase['bfc_1'] net = tf.nn.dropout(net, self.keep_drop) net = tf.nn.relu(net) print('fc第一層', net) # fc第二層 net = tf.matmul(net, weight['wfc_2']) + biase['bfc_2'] net = tf.nn.dropout(net, self.keep_drop) net = tf.nn.relu(net) print('fc第二層', net) # fc第三層 net = tf.matmul(net, weight['wfc_3']) + biase['bfc_3'] print('fc第三層', net) return net def getTrainDataset(self): """ 整理資料集,把影象resize為128*64*3,訓練集做成self.im_total_num*128*64*3,把label做成0,1向量形式 :return: """ train_data_list = os.listdir('./dataset/train/') print('共有%d張訓練圖片, 讀取%d張:' % (len(train_data_list), self.train_max_num)) random.shuffle(train_data_list)# 打亂順序 y_val_list, y_train_list = [], [] x_val_list = train_data_list[:self.val_num] for x_val in x_val_list: words_tmp = x_val.split('.')[0].split('_')[1] y_val_list.append([1 if _w == w else 0 for w in words_tmp for _w in self.words]) x_train_list = train_data_list[self.val_num:self.train_max_num] for x_train in x_train_list: words_tmp = x_train.split('.')[0].split('_')[1] y_train_list.append([1 if _w == w else 0 for w in words_tmp for _w in self.words]) return x_train_list, y_train_list, x_val_list, y_val_list def createCaptchaDataset(self): """ 生成訓練用圖片資料集 :return: """ image = ImageCaptcha(width=self.im_width, height=self.im_height, font_sizes=(56,)) for i in range(self.im_total_num): words_tmp = '' for j in range(self.words_num): words_tmp = words_tmp + random.choice(self.words) print(words_tmp, type(words_tmp)) im_path = './dataset/train/%d_%s.png' % (i, words_tmp) print(im_path) image.write(words_tmp, im_path) return True def dataset_opt(self, x_train_queue, y_train_queue): """ 處理圖片和標籤 :param queue: :return: """ queue = x_train_queue[0] contents = tf.read_file('./dataset/train/' + queue) im = tf.image.decode_jpeg(contents) im = tf.image.resize_images(images=im, size=[self.im_height, self.im_width]) im = tf.reshape(im, tf.stack([self.im_height, self.im_width, 3])) im -= 147# 去均值化 # im /= 255# 將畫素處理在0~1之間,加速收斂 # im -= 0.5# 將畫素處理在-0.5~0.5之間 return im, y_train_queue[0] if __name__ == '__main__': opt_type = sys.argv[1:][0] instance = CnnOcr() if opt_type == 'create_dataset': instance.createCaptchaDataset() elif opt_type == 'train': instance.train() elif opt_type == 'test': instance.test('./dataset/test/0_HZDZ.png') elif opt_type == 'start': # 將session持久化到記憶體中 instance.test('./dataset/test/0_HZDZ.png') # 啟動web服務 # http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/2_SYVD.png @app.route('/captchaOcr', methods=['GET']) def captchaOcr(): img_path = request.args.to_dict().get('img_path') print(img_path) ret = instance.captchaOcr(img_path) print(ret) return json.dumps({'img_path': img_path, 'ocr_ret': ret}) app.run(host='0.0.0.0', port=5050, debug=False) 複製程式碼