TensorFlow多線程輸入數據處理框架(四)——輸入數據處理框架
阿新 • • 發佈:2019-02-08
nat 數據解析 con NPU die thread 深度 variable glob
參考書
《TensorFlow:實戰Google深度學習框架》(第2版)
輸入數據處理的整個流程。
#!/usr/bin/env python # -*- coding: UTF-8 -*- # coding=utf-8 """ @author: Li Tian @contact: [email protected] @software: pycharm @file: sample_data_deal.py @time: 2019/2/8 20:30 @desc: 輸入數據處理框架 """ from figuredata_deal.figure_deal_test2 import preprocess_for_trainimport tensorflow as tf # 創建文件列表,並通過文件列表創建輸入文件隊列。在調用輸入數據處理流程前,需要統一所有原始數據的格式 # 並將它們存儲到TFRecord文件中。下面給出的文件列表應該包含所有提供訓練數據的TFRecord文件。 files = tf.train.match_filenames_once(‘file_pattern-*‘) filename_queue = tf.train.string_input_producer(files, shuffle=False) # 使用類似前面介紹的方法解析TFRecord文件裏的數據。這裏假設image中存儲的是圖像的原始數據,label為該# 樣例所對應的標簽。height、width和channels給出了圖片的維度。 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ ‘image‘: tf.FixedLenFeature([], tf.string), ‘label‘: tf.FixedLenFeature([], tf.int64),‘height‘: tf.FixedLenFeature([], tf.int64), ‘width‘: tf.FixedLenFeature([], tf.int64), ‘channels‘: tf.FixedLenFeature([], tf.int64), } ) image, label = features[‘image‘], features[‘label‘] height, width = features[‘height‘], features[‘width‘] channels = features[‘channels‘] # 從原始圖像數據解析出像素矩陣,並根據圖像尺寸還原圖像。 decoded_image = tf.decode_raw(image, tf.uint8) decoded_image.set_shape([height, width, channels]) # 定義神經網絡輸入層圖片的大小 image_size = 299 # preprocess_for_train為前面提到的圖像預處理程序 distorted_image = preprocess_for_train(decoded_image, image_size, image_size, None) # 將處理後的圖像和標簽數據通過tf.train.shuffle_batch整理成神經網絡訓練時需要的batch。 min_after_dequeue = 10000 batch_size = 100 capacity = min_after_dequeue + 3 * batch_size image_batch, label_batch = tf.train.shuffle_batch([distorted_image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue) # 定義神經網絡的結構以及優化過程, image_batch可以作為輸入提供給神經網絡的輸入層。 # label_batch則提供了輸入batch中樣例的正確答案。 # 學習率 learning_rate = 0.01 # inference是神經網絡的結構 logit = inference(image_batch) # loss是計算神經網絡的損失函數 loss = cal_loss(logit, label_batch) # 訓練過程 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss) # 聲明會話並運行神經網絡的優化過程 with tf.Session() as sess: # 神經網絡訓練準備工作。這些工作包括變量初始化、線程啟動。 sess.run((tf.global_variables_initializer(), tf.local_variables_initializer())) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 神經網絡訓練過程。 TRAINING_ROUNDS = 5000 for i in range(TRAINING_ROUNDS): sess.run(train_step) # 停止所有線程 coord.request_stop() coord.join(threads)
TensorFlow多線程輸入數據處理框架(四)——輸入數據處理框架