MNIST資料集輸出識別準確率用到的函式方法
阿新 • • 發佈:2018-12-13
MNIST資料集輸出識別準確率
MINST資料集:
提供6w張28*28畫素點的0~9手寫數字圖片和標籤,用於訓練;
提供1w張2828畫素點的0~9手寫數字圖片和標籤,用於測試。 每張圖片的784個畫素點(2828=784)組成長度為784的一維陣列,作為輸入特徵:
eg:[0. 0. 0. 0. 0. 0. 0.380 0.376 0.301 0.462 … … … 0.239 0. 0. 0. 0. 0. 0. 0. 0.]
圖片的標籤以一維陣列形式給出,每個元素表示對應分類出現的概率。
eg:[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]這個表示數字6的概率是100%其他數字的概率均是0,意思就是圖片應該是阿拉伯數字6。
from tensorflow. examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./data/', one_hot=True)
Train and Validation:→訓練和驗證模型引數
- train-images-idx3-ubyte.gz
- train-labels-idx1-ubyte.gz
Test:→測試模型
- t10k-images-idx3-ubyte.gz
- t10k-labels-idx1-ubyte.gz
返回各子集樣本數:
mnist.train.num_examples
mnist.validation. num_examples
mnist.test.num_examples
返回標籤和資料:
>>>mnist.train.labels[0] # 表示第0張圖片的標籤
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
>>>mnist.train.images[0] # 表示第0張圖片的784個畫素點
array([0. , 0. ,0.
0. , 0. ,0.
0. , 0. , 0.
0. , 0. ,0.
..................................
])
# 總共784個畫素點
收一小撮資料,準備喂入神經網路訓練:
>>>BATCH_SIZE = 200 # 定義一小撮是多少
>>>xs, ys = mnist.train.next_batch(BATCH_SIZE) # 從訓練集中隨機抽取BATCH_SIZE組個數據和標籤
>>>print("xs shape:", xs.shape)
xs.shape: (200,784) # 200行資料 每個資料有784個畫素點
>>>print("ys shape:", ys.shape)
ys shape: (200,10) # 200行資料 每個資料有10個元素是輸出的分類
下面是一些常用的函式:
tf.get_collection("") # 從集合中取全部變數,生成一個列表
tf.add_n([]) # 列表內對應元素相加
tf.cast(x,dtype) # 把x轉為dtype型別
tf.argmax(x,axis) # 返回最大值所在的索引號,如:tf.argmax([1,0,0],1) 返回0
os.path.join("home", "name") # 返回home/name
字串.spilt() # 按指定字串對字串進行切片,返回分割後的列表
with tf.Graph().as_default() as g: # 其內定義的節點在計算圖g
儲存模型:
saver = tf.train.Saver() # 例項化saver物件
with tf.Session as sess:
for i in range(STEPS):
if i % 輪數 == 0:
saver.save(sess, os.path.join(MODEL_SAVE_PATH,MODEL_NAME), global_step=global_step)
載入模型:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(儲存路徑)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
例項化可還原滑動平均值的saver
ema = tf.train.ExponentialMovingAverage(滑動平均基數)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)
準確率計算方法:
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))