1. 程式人生 > >MNIST資料集輸出識別準確率用到的函式方法

MNIST資料集輸出識別準確率用到的函式方法

MNIST資料集輸出識別準確率

MINST資料集:

提供6w張28*28畫素點的0~9手寫數字圖片和標籤,用於訓練;

提供1w張2828畫素點的0~9手寫數字圖片和標籤,用於測試。 在這裡插入圖片描述 每張圖片的784個畫素點(2828=784)組成長度為784的一維陣列,作為輸入特徵

eg:[0. 0. 0. 0. 0. 0. 0.380 0.376 0.301 0.462 … … … 0.239 0. 0. 0. 0. 0. 0. 0. 0.]

圖片的標籤以一維陣列形式給出,每個元素表示對應分類出現的概率。

eg:[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]這個表示數字6的概率是100%其他數字的概率均是0,意思就是圖片應該是阿拉伯數字6。

from tensorflow.
examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('./data/', one_hot=True)

Train and Validation:→訓練和驗證模型引數

  • train-images-idx3-ubyte.gz
  • train-labels-idx1-ubyte.gz

Test:→測試模型

  • t10k-images-idx3-ubyte.gz
  • t10k-labels-idx1-ubyte.gz

返回各子集樣本數:

mnist.train.num_examples
mnist.validation.
num_examples mnist.test.num_examples

返回標籤和資料:

>>>mnist.train.labels[0]   # 表示第0張圖片的標籤
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
>>>mnist.train.images[0]   # 表示第0張圖片的784個畫素點
array([0.             , 0.              ,0.
       0.             , 0.              ,0.
       0.             , 0.              ,
0. 0. , 0. ,0. .................................. ]) # 總共784個畫素點

收一小撮資料,準備喂入神經網路訓練:

>>>BATCH_SIZE = 200                   # 定義一小撮是多少
>>>xs, ys = mnist.train.next_batch(BATCH_SIZE)  # 從訓練集中隨機抽取BATCH_SIZE組個數據和標籤
>>>print("xs shape:", xs.shape)
xs.shape: (200,784)   # 200行資料 每個資料有784個畫素點
>>>print("ys shape:", ys.shape)
ys shape: (200,10)    # 200行資料 每個資料有10個元素是輸出的分類

下面是一些常用的函式:

tf.get_collection("")        # 從集合中取全部變數,生成一個列表
tf.add_n([])                 # 列表內對應元素相加
tf.cast(x,dtype)			 # 把x轉為dtype型別
tf.argmax(x,axis)			 # 返回最大值所在的索引號,如:tf.argmax([1,0,0],1) 返回0
os.path.join("home", "name") # 返回home/name
字串.spilt()           	   # 按指定字串對字串進行切片,返回分割後的列表
with tf.Graph().as_default() as g:  # 其內定義的節點在計算圖g

儲存模型:

saver = tf.train.Saver()    # 例項化saver物件
with tf.Session as sess:
	for i in range(STEPS):
        if i % 輪數 == 0:
            saver.save(sess, os.path.join(MODEL_SAVE_PATH,MODEL_NAME), global_step=global_step)

載入模型:

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(儲存路徑)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

例項化可還原滑動平均值的saver

ema = tf.train.ExponentialMovingAverage(滑動平均基數)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)

準確率計算方法:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))