TensorFlow Learning: Training a Convolutional Neural Network on the CIFAR-10 Dataset
阿新 • Published: 2019-02-07
The CIFAR-10 Dataset
The CIFAR-10 dataset contains 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.
The 10 classes, with sample images for each:
Training CIFAR-10 with a Convolutional Neural Network
Local Response Normalization (LRN), introduced in AlexNet, normalizes each activation across adjacent feature maps:

b^i_{x,y} = a^i_{x,y} / ( k + alpha * sum_{j=max(0, i-n/2)}^{min(N-1, i+n/2)} (a^j_{x,y})^2 )^beta

where a^i_{x,y} is the activated output of kernel i at position (x,y); b^i_{x,y} is the LRN output fed to the next layer; N is the total number of feature maps in the layer; and n is the number of adjacent feature maps (n/2 on each side) that the sum runs over.
LRN amplifies neurons with relatively large responses while suppressing those with smaller ones, which improves generalization. Because LRN picks out the larger responses among nearby kernels, it suits activation functions with no upper bound (such as ReLU), but not bounded activations that already suppress large responses (such as sigmoid).
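As a sanity check, the LRN formula above can be sketched in NumPy. This is a hypothetical reimplementation for illustration, not TensorFlow's `tf.nn.lrn`; the `depth_radius`, `bias`, `alpha`, and `beta` defaults mirror the arguments used in the code later in this post:

```python
import numpy as np

def lrn(a, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75):
    """Local Response Normalization over the last (channel) axis.

    a: activations of shape (height, width, N), where N is the number of
    feature maps. Each channel i is divided by
    (bias + alpha * sum of squares of nearby channels) ** beta.
    """
    n_maps = a.shape[-1]
    out = np.empty_like(a)
    for i in range(n_maps):
        lo = max(0, i - depth_radius)
        hi = min(n_maps - 1, i + depth_radius)
        denom = (bias + alpha * np.sum(a[..., lo:hi + 1] ** 2, axis=-1)) ** beta
        out[..., i] = a[..., i] / denom
    return out

a = np.random.default_rng(0).normal(size=(24, 24, 64))
b = lrn(a)
```

Since the denominator is always at least `bias ** beta` = 1 here, every activation is shrunk, and channels surrounded by large responses are shrunk the most.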
Code and comments
Download the TensorFlow Models repository to obtain the classes that provide the CIFAR-10 data:
tensorflow/models link
# Train on the CIFAR-10 dataset
import sys
sys.path.append("/home/w/mycode/TensorFlow/CIFAR10/models/tutorials/image/cifar10")
# Classes for downloading and reading CIFAR-10
import cifar10, cifar10_input
import tensorflow as tf
import numpy as np
import time
# Number of training steps
max_steps = 3000
batch_size = 128
# Default path for the downloaded CIFAR-10 data
data_dir = '/tmp/cifar10_data/cifar-10-batches-bin'
# L1 regularization produces sparse features: most useless feature weights go to 0.
# L2 regularization keeps weights small, so they stay fairly uniform.
# Initialize weights from a normal distribution and optionally add an L2 penalty,
# scaled by w1.
def variable_with_weight_loss(shape, stddev, w1):
    # Sample from a truncated normal (values within 2 standard deviations)
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if w1 is not None:
        # weight_loss = l2_loss(var) * w1
        weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
        # Add to the 'losses' collection of the default graph
        tf.add_to_collection('losses', weight_loss)
    return var
# Download from Alex Krizhevsky's site and extract to the default location
cifar10.maybe_download_and_extract()
# Use Reader ops to build the training data (images and their labels),
# with data augmentation (horizontal flips, random contrast/brightness,
# random cropping) and per-image standardization
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)
# Build the evaluation input (center-crop a 24x24 patch and standardize)
images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size)
# Placeholder for input images (24x24, 3 channels)
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
# Placeholder for input labels
label_holder = tf.placeholder(tf.int32, [batch_size])
# Convolutional layer 1
# 64 5x5 kernels over 3 channels; no L2 penalty on the first conv layer's weights
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, w1=0.0)
# Convolution with stride 1 and SAME padding
kernel1 = tf.nn.conv2d(image_holder, weight1, [1, 1, 1, 1], padding='SAME')
# Biases initialized to 0
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
# Add bias to the convolution output, then apply ReLU
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
# Max pooling, 3x3 window, 2x2 stride
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
# Local Response Normalization:
# amplify large responses, suppress small ones, improve generalization
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)
# Convolutional layer 2
# 64 5x5 kernels over 64 channels; no L2 penalty
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2, w1=0.0)
# Convolution with stride 1 and SAME padding
kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding='SAME')
# Biases initialized to 0.1
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
# Add bias, then apply ReLU
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
# Local Response Normalization (same hyperparameters as norm1)
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)
# Max pooling, 3x3 window, 2x2 stride
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
# Fully connected layers
# Flatten each sample into a 1-D vector
reshape = tf.reshape(pool2, [batch_size, -1])
# Length of the flattened vector
dim = reshape.get_shape()[1].value
# Weight initialization, with an L2 penalty this time
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
# Bias initialization
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)
# Reduce the number of hidden units to 192
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)
# Final 10-class output; the normal distribution's stddev is set to the reciprocal
# of the previous hidden layer's unit count; no L2 penalty
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1/192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)
# Compute the CNN's loss
def loss(logits, labels):
    labels = tf.cast(labels, tf.int64)
    # Sparse softmax cross entropy between logits and labels
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels, name='cross_entropy_per_example')
    # Mean cross entropy over the batch
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    # Add the cross-entropy loss to the overall loss
    tf.add_to_collection('losses', cross_entropy_mean)
    # Sum every loss in the 'losses' collection
    return tf.add_n(tf.get_collection('losses'), name='total_loss')
# Pass logits and label_holder to the loss function to get the final loss
loss = loss(logits, label_holder)
# Use the Adam optimizer with learning rate 1e-3
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
# Top-k accuracy (k=1 here, i.e. the single highest-scoring class)
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)
# Create the default session
sess = tf.InteractiveSession()
# Initialize all model parameters
tf.global_variables_initializer().run()
# Start the thread queues that feed augmented image data
tf.train.start_queue_runners()
for step in range(max_steps):
    start_time = time.time()
    # Run the graph to fetch a batch of training images and labels
    image_batch, label_batch = sess.run([images_train, labels_train])
    # One training step
    _, loss_value = sess.run([train_op, loss],
                             feed_dict={image_holder: image_batch, label_holder: label_batch})
    # Time taken by this step
    duration = time.time() - start_time
    if step % 10 == 0:
        # Training samples per second
        examples_per_sec = batch_size / duration
        # Seconds per batch
        sec_per_batch = float(duration)
        format_str = ('step %d, loss=%.2f (%.1f examples/sec; %.3f sec/batch)')
        print(format_str % (step, loss_value, examples_per_sec, sec_per_batch))
# Number of test samples
num_examples = 10000
import math
# Total number of batches (float division so the ceiling is correct under Python 2 as well)
num_iter = int(math.ceil(num_examples / float(batch_size)))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
    # Fetch a batch of test images and labels
    image_batch, label_batch = sess.run([images_test, labels_test])
    # Which samples were predicted correctly
    predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch,
                                                  label_holder: label_batch})
    # Accumulate the number of correct predictions
    true_count += np.sum(predictions)
    step += 1
# Accuracy on the test set (float division to avoid truncation under Python 2)
precision = true_count / float(total_sample_count)
print('precision @ 1 = %.3f' % precision)
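The evaluation above relies on tf.nn.in_top_k, which marks a prediction correct when the true label is among the k highest-scoring classes. Its behavior can be sketched in NumPy; this is a simplified stand-in for illustration (TensorFlow's version also handles tied logits, which this sketch ignores):

```python
import numpy as np

def in_top_k(predictions, targets, k):
    """For each row, is the target label among the k highest-scoring classes?"""
    top_k = np.argsort(predictions, axis=1)[:, -k:]  # indices of the k largest logits
    return np.array([t in row for t, row in zip(targets, top_k)])

logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])
labels = np.array([1, 2])
correct = in_top_k(logits, labels, 1)               # [True, False]
precision = np.sum(correct) / float(len(correct))   # fraction correct
```

With k=1 this is exactly the "precision @ 1" accumulated in the loop above: sum the booleans over all batches and divide by the sample count.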
Techniques used:
- L2 regularization on the weights.
- Data augmentation (flips, random crops, etc.) to create more training samples.
- An LRN layer in each convolution/max-pool block to improve generalization.
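The weight-decay pattern behind the first of these techniques can be sketched without TensorFlow. Here a plain Python list plays the role of the graph's 'losses' collection, and l2_loss mirrors tf.nn.l2_loss, which computes half the sum of squares; the shapes and values below are illustrative only:

```python
import numpy as np

losses = []  # stands in for the graph's 'losses' collection

def l2_loss(w):
    # Same convention as tf.nn.l2_loss: half the sum of squares
    return 0.5 * np.sum(w ** 2)

def variable_with_weight_loss(shape, stddev, wl, rng=np.random.default_rng(0)):
    var = rng.normal(0.0, stddev, size=shape)
    if wl is not None and wl > 0.0:
        losses.append(l2_loss(var) * wl)  # like tf.add_to_collection('losses', ...)
    return var

w3 = variable_with_weight_loss((10, 384), stddev=0.04, wl=0.004)       # penalized
w1 = variable_with_weight_loss((5, 5, 3, 64), stddev=5e-2, wl=0.0)     # not penalized
cross_entropy = 1.25  # placeholder for the data loss
total_loss = cross_entropy + sum(losses)  # like tf.add_n(tf.get_collection('losses'))
```

Minimizing total_loss then trades classification error against the size of the penalized weights, which is what keeps the fully connected layers from overfitting.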
Data obtained after downloading and extracting:
Output:
Downloading cifar-10-binary.tar.gz 100.0%
Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes.
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
step 0, loss=4.67 (25.7 examples/sec; 4.989 sec/batch)
step 10, loss=3.69 (141.3 examples/sec; 0.906 sec/batch)
step 20, loss=3.20 (152.7 examples/sec; 0.838 sec/batch)
step 30, loss=2.72 (116.9 examples/sec; 1.095 sec/batch)
step 40, loss=2.48 (113.8 examples/sec; 1.124 sec/batch)
step 50, loss=2.38 (145.6 examples/sec; 0.879 sec/batch)
step 60, loss=2.40 (106.2 examples/sec; 1.205 sec/batch)
step 70, loss=2.10 (142.8 examples/sec; 0.896 sec/batch)
step 80, loss=2.01 (148.4 examples/sec; 0.862 sec/batch)
step 90, loss=2.04 (89.5 examples/sec; 1.430 sec/batch)
step 100, loss=1.97 (127.1 examples/sec; 1.007 sec/batch)
step 110, loss=1.92 (114.9 examples/sec; 1.114 sec/batch)
step 120, loss=1.93 (104.7 examples/sec; 1.223 sec/batch)
step 130, loss=1.85 (106.6 examples/sec; 1.201 sec/batch)
step 140, loss=1.82 (98.9 examples/sec; 1.294 sec/batch)
step 150, loss=1.92 (105.4 examples/sec; 1.214 sec/batch)
step 160, loss=1.89 (118.9 examples/sec; 1.077 sec/batch)
step 170, loss=1.68 (115.2 examples/sec; 1.111 sec/batch)
step 180, loss=1.76 (132.7 examples/sec; 0.965 sec/batch)
step 190, loss=1.85 (85.1 examples/sec; 1.504 sec/batch)
step 200, loss=1.75 (98.5 examples/sec; 1.299 sec/batch)
step 210, loss=1.81 (109.9 examples/sec; 1.164 sec/batch)
step 220, loss=1.75 (98.6 examples/sec; 1.298 sec/batch)
step 230, loss=1.79 (99.8 examples/sec; 1.282 sec/batch)
step 240, loss=1.64 (114.7 examples/sec; 1.116 sec/batch)
step 250, loss=1.91 (116.4 examples/sec; 1.100 sec/batch)
step 260, loss=1.68 (104.0 examples/sec; 1.231 sec/batch)
step 270, loss=1.45 (101.3 examples/sec; 1.264 sec/batch)
step 280, loss=1.45 (100.2 examples/sec; 1.277 sec/batch)
step 290, loss=1.67 (70.3 examples/sec; 1.820 sec/batch)
step 300, loss=1.73 (100.5 examples/sec; 1.273 sec/batch)
step 310, loss=1.61 (97.2 examples/sec; 1.317 sec/batch)
step 320, loss=1.63 (93.4 examples/sec; 1.371 sec/batch)
step 330, loss=1.56 (79.9 examples/sec; 1.601 sec/batch)
step 340, loss=1.45 (78.9 examples/sec; 1.622 sec/batch)
step 350, loss=1.54 (85.7 examples/sec; 1.494 sec/batch)
step 360, loss=1.43 (94.6 examples/sec; 1.354 sec/batch)
step 370, loss=1.54 (134.3 examples/sec; 0.953 sec/batch)
step 380, loss=1.52 (70.5 examples/sec; 1.816 sec/batch)
step 390, loss=1.70 (71.0 examples/sec; 1.803 sec/batch)
step 400, loss=1.63 (67.8 examples/sec; 1.887 sec/batch)
step 410, loss=1.35 (72.7 examples/sec; 1.760 sec/batch)
step 420, loss=1.54 (71.4 examples/sec; 1.794 sec/batch)
step 430, loss=1.60 (61.4 examples/sec; 2.084 sec/batch)
step 440, loss=1.65 (69.4 examples/sec; 1.844 sec/batch)
step 450, loss=1.38 (70.7 examples/sec; 1.810 sec/batch)
step 460, loss=1.31 (67.5 examples/sec; 1.898 sec/batch)
step 470, loss=1.55 (53.7 examples/sec; 2.385 sec/batch)
step 480, loss=1.42 (55.3 examples/sec; 2.314 sec/batch)
step 490, loss=1.33 (52.6 examples/sec; 2.432 sec/batch)
step 500, loss=1.31 (61.5 examples/sec; 2.082 sec/batch)
step 510, loss=1.68 (69.3 examples/sec; 1.846 sec/batch)
step 520, loss=1.47 (57.8 examples/sec; 2.213 sec/batch)
step 530, loss=1.50 (60.7 examples/sec; 2.109 sec/batch)
step 540, loss=1.25 (57.3 examples/sec; 2.236 sec/batch)
step 550, loss=1.61 (51.4 examples/sec; 2.492 sec/batch)
step 560, loss=1.33 (53.3 examples/sec; 2.403 sec/batch)
step 570, loss=1.48 (53.9 examples/sec; 2.373 sec/batch)
step 580, loss=1.42 (51.9 examples/sec; 2.468 sec/batch)
step 590, loss=1.49 (61.7 examples/sec; 2.073 sec/batch)
step 600, loss=1.52 (50.6 examples/sec; 2.532 sec/batch)
step 610, loss=1.31 (64.2 examples/sec; 1.994 sec/batch)
step 620, loss=1.61 (70.4 examples/sec; 1.819 sec/batch)
step 630, loss=1.30 (69.8 examples/sec; 1.835 sec/batch)
step 640, loss=1.45 (71.3 examples/sec; 1.796 sec/batch)
step 650, loss=1.43 (71.9 examples/sec; 1.780 sec/batch)
step 660, loss=1.58 (74.2 examples/sec; 1.726 sec/batch)
step 670, loss=1.29 (71.5 examples/sec; 1.790 sec/batch)
step 680, loss=1.11 (72.0 examples/sec; 1.777 sec/batch)
step 690, loss=1.20 (61.7 examples/sec; 2.074 sec/batch)
step 700, loss=1.36 (70.1 examples/sec; 1.827 sec/batch)
step 710, loss=1.39 (66.9 examples/sec; 1.914 sec/batch)
step 720, loss=1.49 (69.8 examples/sec; 1.833 sec/batch)
step 730, loss=1.63 (67.3 examples/sec; 1.901 sec/batch)
step 740, loss=1.42 (69.1 examples/sec; 1.852 sec/batch)
step 750, loss=1.34 (68.4 examples/sec; 1.871 sec/batch)
step 760, loss=1.33 (69.4 examples/sec; 1.843 sec/batch)
step 770, loss=1.48 (70.0 examples/sec; 1.830 sec/batch)
step 780, loss=1.34 (70.8 examples/sec; 1.808 sec/batch)
step 790, loss=1.39 (72.0 examples/sec; 1.778 sec/batch)
step 800, loss=1.31 (71.6 examples/sec; 1.788 sec/batch)
step 810, loss=1.30 (69.8 examples/sec; 1.834 sec/batch)
step 820, loss=1.27 (72.3 examples/sec; 1.770 sec/batch)
step 830, loss=1.36 (70.0 examples/sec; 1.830 sec/batch)
step 840, loss=1.36 (54.0 examples/sec; 2.369 sec/batch)
step 850, loss=1.40 (59.3 examples/sec; 2.157 sec/batch)
step 860, loss=1.18 (66.9 examples/sec; 1.912 sec/batch)
step 870, loss=1.29 (52.6 examples/sec; 2.433 sec/batch)
step 880, loss=1.18 (71.2 examples/sec; 1.799 sec/batch)
step 890, loss=1.28 (67.8 examples/sec; 1.888 sec/batch)
step 900, loss=1.13 (72.2 examples/sec; 1.773 sec/batch)
step 910, loss=1.26 (67.1 examples/sec; 1.908 sec/batch)
step 920, loss=1.30 (69.5 examples/sec; 1.841 sec/batch)
step 930, loss=1.23 (71.3 examples/sec; 1.795 sec/batch)
step 940, loss=1.20 (71.8 examples/sec; 1.782 sec/batch)
step 950, loss=1.50 (71.0 examples/sec; 1.803 sec/batch)
step 960, loss=1.31 (74.3 examples/sec; 1.724 sec/batch)
step 970, loss=1.12 (75.0 examples/sec; 1.707 sec/batch)
step 980, loss=1.31 (72.3 examples/sec; 1.770 sec/batch)
step 990, loss=1.25 (73.3 examples/sec; 1.746 sec/batch)
step 1000, loss=1.38 (71.4 examples/sec; 1.792 sec/batch)
step 1010, loss=1.21 (69.0 examples/sec; 1.856 sec/batch)
step 1020, loss=1.13 (68.5 examples/sec; 1.869 sec/batch)
step 1030, loss=1.18 (55.4 examples/sec; 2.309 sec/batch)
step 1040, loss=1.21 (61.3 examples/sec; 2.089 sec/batch)
step 1050, loss=1.20 (49.6 examples/sec; 2.580 sec/batch)
step 1060, loss=1.20 (53.5 examples/sec; 2.394 sec/batch)
step 1070, loss=1.18 (55.6 examples/sec; 2.301 sec/batch)
step 1080, loss=1.33 (58.4 examples/sec; 2.190 sec/batch)
step 1090, loss=1.28 (63.0 examples/sec; 2.032 sec/batch)
step 1100, loss=1.38 (63.5 examples/sec; 2.016 sec/batch)
step 1110, loss=1.22 (64.6 examples/sec; 1.983 sec/batch)
step 1120, loss=1.46 (64.5 examples/sec; 1.983 sec/batch)
step 1130, loss=1.21 (71.6 examples/sec; 1.787 sec/batch)
step 1140, loss=1.42 (72.4 examples/sec; 1.767 sec/batch)
step 1150, loss=1.20 (73.6 examples/sec; 1.738 sec/batch)
step 1160, loss=1.26 (70.1 examples/sec; 1.827 sec/batch)
step 1170, loss=1.13 (73.2 examples/sec; 1.748 sec/batch)
step 1180, loss=1.28 (68.1 examples/sec; 1.879 sec/batch)
step 1190, loss=1.23 (73.7 examples/sec; 1.737 sec/batch)
step 1200, loss=1.16 (73.5 examples/sec; 1.742 sec/batch)
step 1210, loss=1.17 (68.4 examples/sec; 1.871 sec/batch)
step 1220, loss=1.36 (72.3 examples/sec; 1.771 sec/batch)
step 1230, loss=1.21 (67.8 examples/sec; 1.887 sec/batch)
step 1240, loss=1.21 (67.7 examples/sec; 1.889 sec/batch)
step 1250, loss=1.21 (71.3 examples/sec; 1.795 sec/batch)
step 1260, loss=1.35 (71.2 examples/sec; 1.799 sec/batch)
step 1270, loss=1.22 (69.3 examples/sec; 1.847 sec/batch)
step 1280, loss=1.16 (71.9 examples/sec; 1.781 sec/batch)
step 1290, loss=1.14 (69.0 examples/sec; 1.856 sec/batch)
step 1300, loss=1.22 (72.0 examples/sec; 1.777 sec/batch)
step 1310, loss=1.42 (71.0 examples/sec; 1.803 sec/batch)
step 1320, loss=1.33 (72.0 examples/sec; 1.777 sec/batch)
step 1330, loss=1.30 (72.4 examples/sec; 1.769 sec/batch)
step 1340, loss=1.30 (68.1 examples/sec; 1.881 sec/batch)
step 1350, loss=1.21 (72.0 examples/sec; 1.779 sec/batch)
step 1360, loss=1.04 (72.5 examples/sec; 1.766 sec/batch)
step 1370, loss=1.41 (71.3 examples/sec; 1.796 sec/batch)
step 1380, loss=1.14 (74.2 examples/sec; 1.726 sec/batch)
step 1390, loss=1.35 (69.9 examples/sec; 1.831 sec/batch)
step 1400, loss=1.33 (71.5 examples/sec; 1.790 sec/batch)
step 1410, loss=1.44 (70.0 examples/sec; 1.828 sec/batch)
step 1420, loss=1.15 (70.4 examples/sec; 1.819 sec/batch)
step 1430, loss=1.19 (72.4 examples/sec; 1.769 sec/batch)
step 1440, loss=1.18 (58.2 examples/sec; 2.201 sec/batch)
step 1450, loss=1.30 (56.1 examples/sec; 2.280 sec/batch)
step 1460, loss=1.11 (50.1 examples/sec; 2.557 sec/batch)
step 1470, loss=1.24 (53.7 examples/sec; 2.385 sec/batch)
step 1480, loss=1.52 (50.7 examples/sec; 2.527 sec/batch)
step 1490, loss=1.08 (52.6 examples/sec; 2.433 sec/batch)
step 1500, loss=1.12 (54.7 examples/se