
TensorFlow Learning: Training a Convolutional Neural Network on the CIFAR-10 Dataset

The CIFAR-10 Dataset

The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.
The 10 classes, with a few sample images from each:

[Image: sample images for each of the 10 CIFAR-10 classes]

Training a Convolutional Neural Network on CIFAR-10

$$b^{i}_{x,y} = a^{i}_{x,y} \Big/ \left( k + \alpha \sum_{j=\max(0,\, i-n/2)}^{\min(N-1,\, i+n/2)} \left( a^{j}_{x,y} \right)^{2} \right)^{\beta}$$

$a^{i}_{x,y}$ is the activated output of kernel $i$ at position $(x,y)$;
$b^{i}_{x,y}$ is the value fed to the next layer after the LRN layer;
$N$ is the total number of feature maps in the layer;
$n$ is the size of the normalization window: the $n/2$ feature maps on either side of the current one enter the sum;
$k$, $\alpha$, and $\beta$ are hyperparameters (the bias, alpha, and beta arguments of tf.nn.lrn).

LRN amplifies neurons with large responses while suppressing those with small responses, which improves generalization. Because LRN picks out the larger responses among nearby kernels, it suits activation functions without an upper bound (such as ReLU), but not activations with a fixed bound that already damp overly large responses (such as sigmoid).
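
To make the formula concrete, here is a minimal NumPy sketch of the same normalization. It follows the convention of tf.nn.lrn (used in the training code below), where depth_radius plays the role of $n/2$ and alpha multiplies the raw sum of squares; the default arguments mirror the values passed later in the script:

```python
import numpy as np

def lrn(a, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75):
    """Local Response Normalization over the channel axis of an NHWC array.

    Matches tf.nn.lrn's definition: channel i is divided by
    (bias + alpha * sum of squares over channels [i - r, i + r]) ** beta,
    with the window clipped to the valid channel range.
    """
    num_maps = a.shape[-1]              # N: total number of feature maps
    b = np.empty_like(a)
    for i in range(num_maps):
        lo = max(0, i - depth_radius)   # clip the window at the edges
        hi = min(num_maps - 1, i + depth_radius)
        sq_sum = np.sum(a[..., lo:hi + 1] ** 2, axis=-1)
        b[..., i] = a[..., i] / (bias + alpha * sq_sum) ** beta
    return b

# A large response among small neighbors is kept; uniform responses shrink.
x = np.random.rand(1, 24, 24, 64).astype(np.float32)
print(lrn(x).shape)  # (1, 24, 24, 64)
```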

Code and Comments

Download the TensorFlow Models repository (https://github.com/tensorflow/models) to obtain the classes that provide the CIFAR-10 data:
tensorflow/models link

```python
# Train a CNN on the CIFAR-10 dataset
import sys
sys.path.append("/home/w/mycode/TensorFlow/CIFAR10/models/tutorials/image/cifar10")
# Classes for downloading and reading CIFAR-10
import cifar10, cifar10_input
import tensorflow as tf
import numpy as np
import time
import math

# Number of training steps
max_steps = 3000
batch_size = 128
# Default path for the downloaded CIFAR-10 data
data_dir = '/tmp/cifar10_data/cifar-10-batches-bin'

# L1 regularization produces sparse features: most useless feature weights
# get pushed to 0. L2 regularization keeps weights from growing too large,
# so the feature weights stay fairly even.
# Initialize weights from a normal distribution and add an L2 penalty;
# w1 controls the size of the L2 loss.
def variable_with_weight_loss(shape, stddev, w1):
    # Draw random values from a truncated normal (within 2 standard deviations)
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if w1 is not None:
        # l2_loss(var) * w1
        weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
        # Collect the weight loss on the default graph
        tf.add_to_collection('losses', weight_loss)
    return var

# Download from Alex Krizhevsky's site and extract to the default location
cifar10.maybe_download_and_extract()

# Build the training inputs (features and their labels) with Reader ops.
# The data is augmented (horizontal flips / random contrast and brightness /
# random crops) and standardized.
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                            batch_size=batch_size)
# Build the evaluation inputs (crop the central 24x24 patch and standardize)
images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir,
                                                batch_size=batch_size)

# Placeholder for the input images (24x24, 3 channels)
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
# Placeholder for the input labels
label_holder = tf.placeholder(tf.int32, [batch_size])

# Convolutional layer 1:
# 64 5x5 kernels over 3 channels; no L2 penalty on the first layer's weights
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, w1=0.0)
# Convolution with stride 1 and SAME padding
kernel1 = tf.nn.conv2d(image_holder, weight1, [1, 1, 1, 1], padding='SAME')
# Biases initialized to 0
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
# Add the bias, then apply ReLU
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
# Max pooling, 3x3 window with 2x2 stride
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
# Local Response Normalization: amplify large responses, suppress small ones,
# to improve generalization
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)

# Convolutional layer 2:
# 64 5x5 kernels over 64 channels; no L2 penalty
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2, w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding='SAME')
# Biases initialized to 0.1
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
# LRN before pooling this time
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

# Fully connected layers:
# flatten each sample into a 1-D vector
reshape = tf.reshape(pool2, [batch_size, -1])
# Length of the flattened vector
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# Reduce the number of hidden units to 192
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# Final 10-way output; the normal distribution's stddev is set to the
# reciprocal of the previous hidden layer's unit count, with no L2 penalty
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1/192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# Compute the CNN's loss
def loss(logits, labels):
    labels = tf.cast(labels, tf.int64)
    # Sparse softmax cross-entropy between logits and labels
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels, name='cross_entropy_per_example')
    # Mean of the cross-entropy
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    # Add the cross-entropy loss to the overall loss
    tf.add_to_collection('losses', cross_entropy_mean)
    # Sum every loss in the 'losses' collection
    return tf.add_n(tf.get_collection('losses'), name='total_loss')

# Pass logits and label_holder into the loss function to get the final loss
loss = loss(logits, label_holder)
# Adam optimizer with a learning rate of 1e-3
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
# Top-k accuracy (k=1 here, i.e. the highest-scoring class)
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# Create a default session
sess = tf.InteractiveSession()
# Initialize all model parameters
tf.global_variables_initializer().run()
# Start the thread queues that feed the augmented image data
tf.train.start_queue_runners()

for step in range(max_steps):
    start_time = time.time()
    # Fetch a training batch of images and labels
    image_batch, label_batch = sess.run([images_train, labels_train])
    _, loss_value = sess.run([train_op, loss],
                             feed_dict={image_holder: image_batch,
                                        label_holder: label_batch})
    # Time spent on this step
    duration = time.time() - start_time
    if step % 10 == 0:
        # Training samples per second
        examples_per_sec = batch_size / duration
        # Time per batch
        sec_per_batch = float(duration)
        format_str = 'step %d, loss=%.2f (%.1f examples/sec; %.3f sec/batch)'
        print(format_str % (step, loss_value, examples_per_sec, sec_per_batch))

# Number of test samples
num_examples = 10000
# Number of batches in total
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
    # Fetch a test batch of images and labels
    image_batch, label_batch = sess.run([images_test, labels_test])
    # Which samples in this batch were predicted correctly
    predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch,
                                                  label_holder: label_batch})
    # Accumulate the correct predictions
    true_count += np.sum(predictions)
    step += 1

# Accuracy on the test set
precision = true_count / float(total_sample_count)
print('precision @ 1 = %.3f' % precision)
```
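
As a sanity check on the flattened length `dim` computed above: with SAME padding, each 3x3 max-pool with stride 2 halves the 24x24 spatial resolution, so

$$24 \xrightarrow{\text{pool1}} 12 \xrightarrow{\text{pool2}} 6, \qquad \text{dim} = 6 \times 6 \times 64 = 2304.$$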

The techniques used here:

  1. L2 regularization on the weights.
  2. Data augmentation on the images (horizontal flips, random crops, and so on) to manufacture more training samples; see the sketch after this list.
  3. An LRN layer after each convolution/max-pooling pair to improve the model's generalization.
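
For reference, the augmentation in item 2 boils down to a handful of tf.image ops. The sketch below mirrors what cifar10_input.distorted_inputs does to each training image in the TF 1.x tutorial code; the jitter parameters (max_delta=63, contrast in [0.2, 1.8]) are the tutorial's values and should be treated as assumptions if your copy differs:

```python
import tensorflow as tf

def distort_image(image):
    """Per-image augmentation in the spirit of cifar10_input.distorted_inputs:
    random crop, horizontal flip, brightness/contrast jitter, standardization.
    `image` is assumed to be a [32, 32, 3] float32 tensor."""
    distorted = tf.random_crop(image, [24, 24, 3])           # random 24x24 patch
    distorted = tf.image.random_flip_left_right(distorted)   # horizontal mirror
    distorted = tf.image.random_brightness(distorted, max_delta=63)
    distorted = tf.image.random_contrast(distorted, lower=0.2, upper=1.8)
    # Subtract the mean and divide by the adjusted stddev of each image
    return tf.image.per_image_standardization(distorted)
```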

The data obtained after downloading and extracting:

[Image: the extracted CIFAR-10 data files]

Output:

Downloading cifar-10-binary.tar.gz 100.0%
Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes.
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.

step 0, loss=4.67 (25.7 examples/sec; 4.989 sec/batch)
step 10, loss=3.69 (141.3 examples/sec; 0.906 sec/batch)
step 20, loss=3.20 (152.7 examples/sec; 0.838 sec/batch)
step 30, loss=2.72 (116.9 examples/sec; 1.095 sec/batch)
step 40, loss=2.48 (113.8 examples/sec; 1.124 sec/batch)
step 50, loss=2.38 (145.6 examples/sec; 0.879 sec/batch)
step 60, loss=2.40 (106.2 examples/sec; 1.205 sec/batch)
step 70, loss=2.10 (142.8 examples/sec; 0.896 sec/batch)
step 80, loss=2.01 (148.4 examples/sec; 0.862 sec/batch)
step 90, loss=2.04 (89.5 examples/sec; 1.430 sec/batch)
step 100, loss=1.97 (127.1 examples/sec; 1.007 sec/batch)
step 110, loss=1.92 (114.9 examples/sec; 1.114 sec/batch)
step 120, loss=1.93 (104.7 examples/sec; 1.223 sec/batch)
step 130, loss=1.85 (106.6 examples/sec; 1.201 sec/batch)
step 140, loss=1.82 (98.9 examples/sec; 1.294 sec/batch)
step 150, loss=1.92 (105.4 examples/sec; 1.214 sec/batch)
step 160, loss=1.89 (118.9 examples/sec; 1.077 sec/batch)
step 170, loss=1.68 (115.2 examples/sec; 1.111 sec/batch)
step 180, loss=1.76 (132.7 examples/sec; 0.965 sec/batch)
step 190, loss=1.85 (85.1 examples/sec; 1.504 sec/batch)
step 200, loss=1.75 (98.5 examples/sec; 1.299 sec/batch)
step 210, loss=1.81 (109.9 examples/sec; 1.164 sec/batch)
step 220, loss=1.75 (98.6 examples/sec; 1.298 sec/batch)
step 230, loss=1.79 (99.8 examples/sec; 1.282 sec/batch)
step 240, loss=1.64 (114.7 examples/sec; 1.116 sec/batch)
step 250, loss=1.91 (116.4 examples/sec; 1.100 sec/batch)
step 260, loss=1.68 (104.0 examples/sec; 1.231 sec/batch)
step 270, loss=1.45 (101.3 examples/sec; 1.264 sec/batch)
step 280, loss=1.45 (100.2 examples/sec; 1.277 sec/batch)
step 290, loss=1.67 (70.3 examples/sec; 1.820 sec/batch)
step 300, loss=1.73 (100.5 examples/sec; 1.273 sec/batch)
step 310, loss=1.61 (97.2 examples/sec; 1.317 sec/batch)
step 320, loss=1.63 (93.4 examples/sec; 1.371 sec/batch)
step 330, loss=1.56 (79.9 examples/sec; 1.601 sec/batch)
step 340, loss=1.45 (78.9 examples/sec; 1.622 sec/batch)
step 350, loss=1.54 (85.7 examples/sec; 1.494 sec/batch)
step 360, loss=1.43 (94.6 examples/sec; 1.354 sec/batch)
step 370, loss=1.54 (134.3 examples/sec; 0.953 sec/batch)
step 380, loss=1.52 (70.5 examples/sec; 1.816 sec/batch)
step 390, loss=1.70 (71.0 examples/sec; 1.803 sec/batch)
step 400, loss=1.63 (67.8 examples/sec; 1.887 sec/batch)
step 410, loss=1.35 (72.7 examples/sec; 1.760 sec/batch)
step 420, loss=1.54 (71.4 examples/sec; 1.794 sec/batch)
step 430, loss=1.60 (61.4 examples/sec; 2.084 sec/batch)
step 440, loss=1.65 (69.4 examples/sec; 1.844 sec/batch)
step 450, loss=1.38 (70.7 examples/sec; 1.810 sec/batch)
step 460, loss=1.31 (67.5 examples/sec; 1.898 sec/batch)
step 470, loss=1.55 (53.7 examples/sec; 2.385 sec/batch)
step 480, loss=1.42 (55.3 examples/sec; 2.314 sec/batch)
step 490, loss=1.33 (52.6 examples/sec; 2.432 sec/batch)
step 500, loss=1.31 (61.5 examples/sec; 2.082 sec/batch)
step 510, loss=1.68 (69.3 examples/sec; 1.846 sec/batch)
step 520, loss=1.47 (57.8 examples/sec; 2.213 sec/batch)
step 530, loss=1.50 (60.7 examples/sec; 2.109 sec/batch)
step 540, loss=1.25 (57.3 examples/sec; 2.236 sec/batch)
step 550, loss=1.61 (51.4 examples/sec; 2.492 sec/batch)
step 560, loss=1.33 (53.3 examples/sec; 2.403 sec/batch)
step 570, loss=1.48 (53.9 examples/sec; 2.373 sec/batch)
step 580, loss=1.42 (51.9 examples/sec; 2.468 sec/batch)
step 590, loss=1.49 (61.7 examples/sec; 2.073 sec/batch)
step 600, loss=1.52 (50.6 examples/sec; 2.532 sec/batch)
step 610, loss=1.31 (64.2 examples/sec; 1.994 sec/batch)
step 620, loss=1.61 (70.4 examples/sec; 1.819 sec/batch)
step 630, loss=1.30 (69.8 examples/sec; 1.835 sec/batch)
step 640, loss=1.45 (71.3 examples/sec; 1.796 sec/batch)
step 650, loss=1.43 (71.9 examples/sec; 1.780 sec/batch)
step 660, loss=1.58 (74.2 examples/sec; 1.726 sec/batch)
step 670, loss=1.29 (71.5 examples/sec; 1.790 sec/batch)
step 680, loss=1.11 (72.0 examples/sec; 1.777 sec/batch)
step 690, loss=1.20 (61.7 examples/sec; 2.074 sec/batch)
step 700, loss=1.36 (70.1 examples/sec; 1.827 sec/batch)
step 710, loss=1.39 (66.9 examples/sec; 1.914 sec/batch)
step 720, loss=1.49 (69.8 examples/sec; 1.833 sec/batch)
step 730, loss=1.63 (67.3 examples/sec; 1.901 sec/batch)
step 740, loss=1.42 (69.1 examples/sec; 1.852 sec/batch)
step 750, loss=1.34 (68.4 examples/sec; 1.871 sec/batch)
step 760, loss=1.33 (69.4 examples/sec; 1.843 sec/batch)
step 770, loss=1.48 (70.0 examples/sec; 1.830 sec/batch)
step 780, loss=1.34 (70.8 examples/sec; 1.808 sec/batch)
step 790, loss=1.39 (72.0 examples/sec; 1.778 sec/batch)
step 800, loss=1.31 (71.6 examples/sec; 1.788 sec/batch)
step 810, loss=1.30 (69.8 examples/sec; 1.834 sec/batch)
step 820, loss=1.27 (72.3 examples/sec; 1.770 sec/batch)
step 830, loss=1.36 (70.0 examples/sec; 1.830 sec/batch)
step 840, loss=1.36 (54.0 examples/sec; 2.369 sec/batch)
step 850, loss=1.40 (59.3 examples/sec; 2.157 sec/batch)
step 860, loss=1.18 (66.9 examples/sec; 1.912 sec/batch)
step 870, loss=1.29 (52.6 examples/sec; 2.433 sec/batch)
step 880, loss=1.18 (71.2 examples/sec; 1.799 sec/batch)
step 890, loss=1.28 (67.8 examples/sec; 1.888 sec/batch)
step 900, loss=1.13 (72.2 examples/sec; 1.773 sec/batch)
step 910, loss=1.26 (67.1 examples/sec; 1.908 sec/batch)
step 920, loss=1.30 (69.5 examples/sec; 1.841 sec/batch)
step 930, loss=1.23 (71.3 examples/sec; 1.795 sec/batch)
step 940, loss=1.20 (71.8 examples/sec; 1.782 sec/batch)
step 950, loss=1.50 (71.0 examples/sec; 1.803 sec/batch)
step 960, loss=1.31 (74.3 examples/sec; 1.724 sec/batch)
step 970, loss=1.12 (75.0 examples/sec; 1.707 sec/batch)
step 980, loss=1.31 (72.3 examples/sec; 1.770 sec/batch)
step 990, loss=1.25 (73.3 examples/sec; 1.746 sec/batch)
step 1000, loss=1.38 (71.4 examples/sec; 1.792 sec/batch)
step 1010, loss=1.21 (69.0 examples/sec; 1.856 sec/batch)
step 1020, loss=1.13 (68.5 examples/sec; 1.869 sec/batch)
step 1030, loss=1.18 (55.4 examples/sec; 2.309 sec/batch)
step 1040, loss=1.21 (61.3 examples/sec; 2.089 sec/batch)
step 1050, loss=1.20 (49.6 examples/sec; 2.580 sec/batch)
step 1060, loss=1.20 (53.5 examples/sec; 2.394 sec/batch)
step 1070, loss=1.18 (55.6 examples/sec; 2.301 sec/batch)
step 1080, loss=1.33 (58.4 examples/sec; 2.190 sec/batch)
step 1090, loss=1.28 (63.0 examples/sec; 2.032 sec/batch)
step 1100, loss=1.38 (63.5 examples/sec; 2.016 sec/batch)
step 1110, loss=1.22 (64.6 examples/sec; 1.983 sec/batch)
step 1120, loss=1.46 (64.5 examples/sec; 1.983 sec/batch)
step 1130, loss=1.21 (71.6 examples/sec; 1.787 sec/batch)
step 1140, loss=1.42 (72.4 examples/sec; 1.767 sec/batch)
step 1150, loss=1.20 (73.6 examples/sec; 1.738 sec/batch)
step 1160, loss=1.26 (70.1 examples/sec; 1.827 sec/batch)
step 1170, loss=1.13 (73.2 examples/sec; 1.748 sec/batch)
step 1180, loss=1.28 (68.1 examples/sec; 1.879 sec/batch)
step 1190, loss=1.23 (73.7 examples/sec; 1.737 sec/batch)
step 1200, loss=1.16 (73.5 examples/sec; 1.742 sec/batch)
step 1210, loss=1.17 (68.4 examples/sec; 1.871 sec/batch)
step 1220, loss=1.36 (72.3 examples/sec; 1.771 sec/batch)
step 1230, loss=1.21 (67.8 examples/sec; 1.887 sec/batch)
step 1240, loss=1.21 (67.7 examples/sec; 1.889 sec/batch)
step 1250, loss=1.21 (71.3 examples/sec; 1.795 sec/batch)
step 1260, loss=1.35 (71.2 examples/sec; 1.799 sec/batch)
step 1270, loss=1.22 (69.3 examples/sec; 1.847 sec/batch)
step 1280, loss=1.16 (71.9 examples/sec; 1.781 sec/batch)
step 1290, loss=1.14 (69.0 examples/sec; 1.856 sec/batch)
step 1300, loss=1.22 (72.0 examples/sec; 1.777 sec/batch)
step 1310, loss=1.42 (71.0 examples/sec; 1.803 sec/batch)
step 1320, loss=1.33 (72.0 examples/sec; 1.777 sec/batch)
step 1330, loss=1.30 (72.4 examples/sec; 1.769 sec/batch)
step 1340, loss=1.30 (68.1 examples/sec; 1.881 sec/batch)
step 1350, loss=1.21 (72.0 examples/sec; 1.779 sec/batch)
step 1360, loss=1.04 (72.5 examples/sec; 1.766 sec/batch)
step 1370, loss=1.41 (71.3 examples/sec; 1.796 sec/batch)
step 1380, loss=1.14 (74.2 examples/sec; 1.726 sec/batch)
step 1390, loss=1.35 (69.9 examples/sec; 1.831 sec/batch)
step 1400, loss=1.33 (71.5 examples/sec; 1.790 sec/batch)
step 1410, loss=1.44 (70.0 examples/sec; 1.828 sec/batch)
step 1420, loss=1.15 (70.4 examples/sec; 1.819 sec/batch)
step 1430, loss=1.19 (72.4 examples/sec; 1.769 sec/batch)
step 1440, loss=1.18 (58.2 examples/sec; 2.201 sec/batch)
step 1450, loss=1.30 (56.1 examples/sec; 2.280 sec/batch)
step 1460, loss=1.11 (50.1 examples/sec; 2.557 sec/batch)
step 1470, loss=1.24 (53.7 examples/sec; 2.385 sec/batch)
step 1480, loss=1.52 (50.7 examples/sec; 2.527 sec/batch)
step 1490, loss=1.08 (52.6 examples/sec; 2.433 sec/batch)
step 1500, loss=1.12 (54.7 examples/se