1. 程式人生 > >Tensorflow學習筆記-輸入資料處理框架

Tensorflow學習筆記-輸入資料處理框架

Created with Raphaël 2.1.0獲取輸入檔案列表建立輸入檔案佇列從檔案佇列讀取資料整理成Batch作為神經網路的輸入設計損失函式選擇梯度下降法訓練

  對應的程式碼流程如下:

    # 建立檔案列表,並通過檔案列表來建立檔案佇列。在呼叫輸入資料處理流程前,需要統一
    # 所有的原始資料格式,並將它們儲存到TFRecord檔案中
    # match_filenames_once 獲取符合正則表示式的所有檔案
    files = tf.train.match_filenames_once('path/to/file-*-*')
    # 將檔案列表生成檔案佇列
filename_queue = tf.train.string_input_producer(files,shuffle=True) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) # image:儲存影象中的原始資料 # label該樣本所對應的標籤 # width,height,channel features = tf.parse_single_example(serialized_example,features={ 'image'
: tf.FixedLenFeature([],tf.string), 'label': tf.FixedLenFeature([], tf.int64), 'width': tf.FixedLenFeature([], tf.int64), 'heigth': tf.FixedLenFeature([], tf.int64), 'channel': tf.FixedLenFeature([], tf.int64) }) image, label = features['image'], features['label'
] width, height = features['width'], features['height'] channel = features['channel'] # 將原始影象資料解析出畫素矩陣,並根據影象尺寸還原糖影象。 decode_image = tf.decode_raw(image) decode_image.set_shape([width,height,channel]) # 神經網路的輸入大小 image_size = 299 # 對影象進行預處理操作,比對亮度、對比度、隨機裁剪等操作 distorted_image = propocess_train(decode_image,image_size,None) # shuffle_batch中的引數 min_after_dequeue = 1000 batch_size = 100 capacity = min_after_dequeue + 3*batch_size image_batch,label_batch = tf.train.shuffle_batch([distorted_image,label], batch_size=batch_size,capacity=capacity, min_after_dequeue=min_after_dequeue) logit = inference(image_batch) loss = cal_loss(logit,label_batch) train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss) with tf.Session() as sess: # 變數初始化 tf.global_variables_initializer().run() # 執行緒初始化和啟動 coord = tf.train.Coordinator() theads = tf.train.start_queue_runners(sess=sess,coord=coord) for i in range(STEPS): sess.run(train_step) # 停止所有執行緒 coord.request_stop() coord.join(threads)