1. 程式人生 > >tensorflow基礎學習:字元數字驗證碼寫入tfrecord檔案封裝成類

tensorflow基礎學習:字元數字驗證碼寫入tfrecord檔案封裝成類

今天分享一下我寫的一個小小程式,基本可以滿足數字+字元型別字串寫入tfrecord檔案。還請多多指教!

簡單說明:這個是數字+字元4位驗證碼的tfrecord生成程式碼,5位,6位的可以自行修改一下,也就一點程式碼。我因為有點晚了就先不改了,大家加油啦。

  • 先做些準備工作。
  • 所有字元的資料集,用於將字元轉化為它的下標數字。
    再存到tfrecord裡面。以便於後面讀取轉化為one-hot編碼使用。
import tensorflow as tf
import os
import random
import sys
from PIL import Image
import
numpy as np # 所有字元的資料集,用於將字元轉化為它的在列表中的下標數字 # 再存到tfrecord裡面。以便於後面讀取轉化為one-hot編碼使用 char_set = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
  • 在這裡寫成一個類,便於程式碼複用。
  • 大家可以根據需求稍作修改使用。

注意點:路徑記得把反斜槓換了,如 F:\checkimages\,要換為F:/checkimages/,最後面的斜槓別少,F:/checkimages也是不可以的

class Make_Tf_Record(object):
    # 將圖片資料轉化為tf檔案,打包成測試集和訓練集
def __init__(self, captcha_dir, tf_file_save_dir): self.char_set = char_set # 驗證碼的儲存路徑 self.captcha_dir = captcha_dir # 生成的tf檔案路徑 self.tf_file_save_dir = tf_file_save_dir
  • 判斷儲存tfrecord檔案的路徑裡面是否已經存在tfrecord檔案
  • 在後面會呼叫,很簡單的幾句程式碼
    def data_exist(self):
        # 判斷 record 檔案是否存在
       
        for split_name in ['train', 'test']:
            output_filename = os.path.join(self.tf_file_save_dir, split_name + '.tfrecords')
            if not tf.gfile.Exists(output_filename):
                return False
        return True
  • 獲取所有驗證碼圖片的具體路徑
    def get_all_captcha_filename(self, captcha_dir):
    	 # captcha_dir驗證碼圖片所在的路徑
        # 獲取所有驗證碼圖片的具體路徑
        captcha_filenames = []
        for filename in os.listdir(captcha_dir):
            # 獲取檔案路徑
            path = os.path.join(captcha_dir, filename)
            captcha_filenames.append(path)
        return captcha_filenames
  • 為轉化為tf檔案做準備的工作,幾乎都是固定的寫法。
  • 下面這個是為了將圖片的畫素值以bytes型別存進去,
  • 也可以說是:列表形狀,字串格式。 如"[[123,123,1],[,23,4,534]]"。
  • 這裡只是一個簡單說明,實際存進去的還是0和1組成的二進位制資料。
  • 不然也不叫bytes。讀取的時候再解碼一下就好,解碼tensorflow都有可以呼叫的函式,不慌。
  • int64_feature: 就是將驗證碼標籤的下標存進去

    def bytes_feature(self, values):
        # 用來存圖片畫素值
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


	    def int64_feature(self,values):
        # 判斷values是否是列表或者元組,如果不是轉為列表
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  • 下面就是呼叫上面的2個函式,接收處理好的image資料和4個標籤值,序列化一下,返回一個物件。
  • 返回值呼叫一下SerializeToString()就可以寫進去了
    def image_to_tf_example(self, image_data, label0, label1, label2, label3):
        # 這裡預設是4位的字元型別,可以自行修改
        # 傳入圖片的資料和標籤,然後返回Example協議型別的資料
        # 全部以字元型別存進去, 分開字元存是為了多工訓練
        # print(label0, label1, label2, label3)
        
        # 先獲取每個字元對應的下標
        label0 = char_set.index(label0)
        label1 = char_set.index(label1)
        label2 = char_set.index(label2)
        label3 = char_set.index(label3)
        
        # return:返回一個Example物件,後來存進去的時候直接序列化一下 .SerializeToString()
        # 這其實是一個類字典的格式,讀取的時候就會發現確實是這樣
        return tf.train.Example(features=tf.train.Features(feature={
            'image': self.bytes_feature(image_data),
            'label0': self.int64_feature(label0),
            'label1': self.int64_feature(label1),
            'label2': self.int64_feature(label2),
            'label3': self.int64_feature(label3),
        }))
  • 劃重點:關鍵一步,程式碼長一點點。
  • 看註釋好理解
    # 把資料轉為TFRecord格式
    def _convert_dataset(self, split_name, filenames):
    	# 斷言,其實沒啥用,可以直接註釋
        assert split_name in ['train', 'test']

        with tf.Session() as sess:
            # 定義tfrecord檔案的路徑+名字
            output_filename = os.path.join(self.tf_file_save_dir, split_name + '.tfrecords')
            # 開啟一個tf檔案寫入器,取名為tfrecord_writer
            with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                for i, filename in enumerate(filenames):
                    try:
                        sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                        sys.stdout.flush()

                        # 讀取圖片
                        image_data = Image.open(filename)
                        # 根據模型的結構resize
                        image_data = image_data.resize((224, 224))
                        # 灰度化,並轉化為一個儲存著畫素值的多維陣列
                        image_data = np.array(image_data.convert('L'))
                        # 將圖片轉化為bytes
                        image_data = image_data.tobytes()

                        # 獲取label,獲取路徑分割的陣列的最後一個也就是 abcd.jpg
                        labels = filename.split('/')[-1][0:4]
                        # 獲取前面4個標籤值
                        num_labels = []
                        for j in range(4):
                            num_labels.append(labels[j])

                        # 呼叫上面的函式,可以往回看看,生成protocol資料型別
                        example = self.image_to_tf_example(image_data, num_labels[0], num_labels[1], num_labels[2],
                                                           num_labels[3])
							# 呼叫write方法,直接寫入一個圖片的檔案。
							# SerializeToString在上面也提到啦,其實基本都是這麼寫,想多瞭解可以看一下函式介紹
                        tfrecord_writer.write(example.SerializeToString())

                    except IOError as e:
                        print('Could not read:', filename)
                        print('Error:', e)
                        print('Skip it\n')
        sys.stdout.write('\n')
        sys.stdout.flush()
  • 最後的一個主函式,直接建立物件後呼叫這個函式就可以生成tf檔案
  • 其實打亂的步驟可以去掉也是可以的,
  • 原因:get_all_captcha_filename中使用的是os.listdir(),這個函式返回的檔名稱列表就是亂的
    def start(self, test_num):
        # 判斷tfrecord檔案是否存在
        if self.data_exist():
            print('tfcecord檔案已存在')
        else:
            # 獲得所有圖片
            captcha_filenames = self.get_all_captcha_filename(self.captcha_dir)

            # 把資料切分為訓練集和測試集,並打亂
            # 隨機種子設定為0
            random.seed(0)
            random.shuffle(captcha_filenames)
            training_filenames = captcha_filenames[test_num:]
            testing_filenames = captcha_filenames[:test_num]
			
            # 資料轉換
            self._convert_dataset('train', training_filenames)
            self._convert_dataset('test', testing_filenames)

            print('生成tfcecord檔案')
            

小結:這只是我寫的一些自己以後可能會用到的東西順便分享一下,喜歡的化可以關注一下,以後會不斷得分享python各個方向的文章。爬蟲,資料分析,web,資料探勘。大家早透啦!!!