1. 程式人生 > >TensorFlow多線程輸入數據處理框架(四)——輸入數據處理框架

TensorFlow多線程輸入數據處理框架(四)——輸入數據處理框架

nat 數據解析 con NPU die thread 深度 variable glob

參考書

《TensorFlow:實戰Google深度學習框架》(第2版)

輸入數據處理的整個流程。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: [email protected]
@software: pycharm
@file: sample_data_deal.py
@time: 2019/2/8 20:30
@desc: 輸入數據處理框架
"""

from figuredata_deal.figure_deal_test2 import preprocess_for_train
import tensorflow as tf # 創建文件列表,並通過文件列表創建輸入文件隊列。在調用輸入數據處理流程前,需要統一所有原始數據的格式 # 並將它們存儲到TFRecord文件中。下面給出的文件列表應該包含所有提供訓練數據的TFRecord文件。 files = tf.train.match_filenames_once(file_pattern-*) filename_queue = tf.train.string_input_producer(files, shuffle=False) # 使用類似前面介紹的方法解析TFRecord文件裏的數據。這裏假設image中存儲的是圖像的原始數據,label為該
# 樣例所對應的標簽。height、width和channels給出了圖片的維度。 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ image: tf.FixedLenFeature([], tf.string), label: tf.FixedLenFeature([], tf.int64),
height: tf.FixedLenFeature([], tf.int64), width: tf.FixedLenFeature([], tf.int64), channels: tf.FixedLenFeature([], tf.int64), } ) image, label = features[image], features[label] height, width = features[height], features[width] channels = features[channels] # 從原始圖像數據解析出像素矩陣,並根據圖像尺寸還原圖像。 decoded_image = tf.decode_raw(image, tf.uint8) decoded_image.set_shape([height, width, channels]) # 定義神經網絡輸入層圖片的大小 image_size = 299 # preprocess_for_train為前面提到的圖像預處理程序 distorted_image = preprocess_for_train(decoded_image, image_size, image_size, None) # 將處理後的圖像和標簽數據通過tf.train.shuffle_batch整理成神經網絡訓練時需要的batch。 min_after_dequeue = 10000 batch_size = 100 capacity = min_after_dequeue + 3 * batch_size image_batch, label_batch = tf.train.shuffle_batch([distorted_image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue) # 定義神經網絡的結構以及優化過程, image_batch可以作為輸入提供給神經網絡的輸入層。 # label_batch則提供了輸入batch中樣例的正確答案。 # 學習率 learning_rate = 0.01 # inference是神經網絡的結構 logit = inference(image_batch) # loss是計算神經網絡的損失函數 loss = cal_loss(logit, label_batch) # 訓練過程 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss) # 聲明會話並運行神經網絡的優化過程 with tf.Session() as sess: # 神經網絡訓練準備工作。這些工作包括變量初始化、線程啟動。 sess.run((tf.global_variables_initializer(), tf.local_variables_initializer())) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 神經網絡訓練過程。 TRAINING_ROUNDS = 5000 for i in range(TRAINING_ROUNDS): sess.run(train_step) # 停止所有線程 coord.request_stop() coord.join(threads)

TensorFlow多線程輸入數據處理框架(四)——輸入數據處理框架