tensorflow基礎學習:字元數字驗證碼寫入tfrecord檔案封裝成類
阿新 • • 發佈:2018-12-03
今天分享一下我寫的一個小小程式,基本可以滿足數字+字元型別字串寫入tfrecord檔案。還請多多指教!
簡單說明:這個是數字+字元4位驗證碼的tfrecord生成程式碼,5位,6位的可以自行修改一下,也就一點程式碼。我因為有點晚了就先不改了,大家加油啦。
- 先做些準備工作。
- 所有字元的資料集,用於將字元轉化為它的下標數字。
再存到tfrecord裡面。以便於後面讀取轉化為one-hot編碼使用。
import tensorflow as tf
import os
import random
import sys
from PIL import Image
import numpy as np
# 所有字元的資料集,用於將字元轉化為它的在列表中的下標數字
# 再存到tfrecord裡面。以便於後面讀取轉化為one-hot編碼使用
char_set = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k',
'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F',
'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
- 在這裡寫成一個類,便於程式碼複用。
- 大家可以根據需求稍作修改使用。
注意點:路徑記得把反斜槓換了,如 F:\checkimages\,要換為F:/checkimages/,最後面的斜槓別少,F:/checkimages也是不可以的
class Make_Tf_Record(object):
# 將圖片資料轉化為tf檔案,打包成測試集和訓練集
def __init__(self, captcha_dir, tf_file_save_dir):
self.char_set = char_set
# 驗證碼的儲存路徑
self.captcha_dir = captcha_dir
# 生成的tf檔案路徑
self.tf_file_save_dir = tf_file_save_dir
- 判斷儲存tfrecord檔案的路徑裡面是否已經存在tfrecord檔案
- 在後面會呼叫,很簡單的幾句程式碼
def data_exist(self):
# 判斷 record 檔案是否存在
for split_name in ['train', 'test']:
output_filename = os.path.join(self.tf_file_save_dir, split_name + '.tfrecords')
if not tf.gfile.Exists(output_filename):
return False
return True
- 獲取所有驗證碼圖片的具體路徑
def get_all_captcha_filename(self, captcha_dir):
# captcha_dir驗證碼圖片所在的路徑
# 獲取所有驗證碼圖片的具體路徑
captcha_filenames = []
for filename in os.listdir(captcha_dir):
# 獲取檔案路徑
path = os.path.join(captcha_dir, filename)
captcha_filenames.append(path)
return captcha_filenames
- 為轉化為tf檔案做準備的工作,幾乎都是固定的寫法。
- 下面這個是為了將圖片的畫素值以bytes型別存進去,
- 也可以說是:列表形狀,字串格式。 如"[[123,123,1],[,23,4,534]]"。
- 這裡只是一個簡單說明,實際存進去的還是0和1組成的二進位制資料。
- 不然也不叫bytes。讀取的時候再解碼一下就好,解碼tensorflow都有可以呼叫的函式,不慌。
- int64_feature: 就是將驗證碼標籤的下標存進去
def bytes_feature(self, values):
# 用來存圖片畫素值
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def int64_feature(self,values):
# 判斷values是否是列表或者元組,如果不是轉為列表
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
- 下面就是呼叫上面的2個函式,接收處理好的image資料和4個標籤值,序列化一下,返回一個物件。
- 返回值呼叫一下SerializeToString()就可以寫進去了
def image_to_tf_example(self, image_data, label0, label1, label2, label3):
# 這裡預設是4位的字元型別,可以自行修改
# 傳入圖片的資料和標籤,然後返回Example協議型別的資料
# 全部以字元型別存進去, 分開字元存是為了多工訓練
# print(label0, label1, label2, label3)
# 先獲取每個字元對應的下標
label0 = char_set.index(label0)
label1 = char_set.index(label1)
label2 = char_set.index(label2)
label3 = char_set.index(label3)
# return:返回一個Example物件,後來存進去的時候直接序列化一下 .SerializeToString()
# 這其實是一個類字典的格式,讀取的時候就會發現確實是這樣
return tf.train.Example(features=tf.train.Features(feature={
'image': self.bytes_feature(image_data),
'label0': self.int64_feature(label0),
'label1': self.int64_feature(label1),
'label2': self.int64_feature(label2),
'label3': self.int64_feature(label3),
}))
- 劃重點:關鍵一步,程式碼長一點點。
- 看註釋好理解
# 把資料轉為TFRecord格式
def _convert_dataset(self, split_name, filenames):
# 斷言,其實沒啥用,可以直接註釋
assert split_name in ['train', 'test']
with tf.Session() as sess:
# 定義tfrecord檔案的路徑+名字
output_filename = os.path.join(self.tf_file_save_dir, split_name + '.tfrecords')
# 開啟一個tf檔案寫入器,取名為tfrecord_writer
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
for i, filename in enumerate(filenames):
try:
sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
sys.stdout.flush()
# 讀取圖片
image_data = Image.open(filename)
# 根據模型的結構resize
image_data = image_data.resize((224, 224))
# 灰度化,並轉化為一個儲存著畫素值的多維陣列
image_data = np.array(image_data.convert('L'))
# 將圖片轉化為bytes
image_data = image_data.tobytes()
# 獲取label,獲取路徑分割的陣列的最後一個也就是 abcd.jpg
labels = filename.split('/')[-1][0:4]
# 獲取前面4個標籤值
num_labels = []
for j in range(4):
num_labels.append(labels[j])
# 呼叫上面的函式,可以往回看看,生成protocol資料型別
example = self.image_to_tf_example(image_data, num_labels[0], num_labels[1], num_labels[2],
num_labels[3])
# 呼叫write方法,直接寫入一個圖片的檔案。
# SerializeToString在上面也提到啦,其實基本都是這麼寫,想多瞭解可以看一下函式介紹
tfrecord_writer.write(example.SerializeToString())
except IOError as e:
print('Could not read:', filename)
print('Error:', e)
print('Skip it\n')
sys.stdout.write('\n')
sys.stdout.flush()
- 最後的一個主函式,直接建立物件後呼叫這個函式就可以生成tf檔案
- 其實打亂的步驟可以去掉也是可以的,
- 原因:get_all_captcha_filename中使用的是os.listdir(),這個函式返回的檔名稱列表就是亂的
def start(self, test_num):
# 判斷tfrecord檔案是否存在
if self.data_exist():
print('tfcecord檔案已存在')
else:
# 獲得所有圖片
captcha_filenames = self.get_all_captcha_filename(self.captcha_dir)
# 把資料切分為訓練集和測試集,並打亂
# 隨機種子設定為0
random.seed(0)
random.shuffle(captcha_filenames)
training_filenames = captcha_filenames[test_num:]
testing_filenames = captcha_filenames[:test_num]
# 資料轉換
self._convert_dataset('train', training_filenames)
self._convert_dataset('test', testing_filenames)
print('生成tfcecord檔案')