1. 程式人生 > >Tensorflow mnist資料集操作

Tensorflow mnist資料集操作

import sys import numpy as np import matplotlib.pyplot as plt import tensorflow as tf import argparse import matplotlib.image as mpimg from skimage import io  #利用skilt-image 將圖片讀成numpy陣列 from tensorflow.examples.tutorials.mnist import input_data #引入tensorflow的mnist手寫數字識別庫 class Mnister:      #定義一個類     def _init_(self):         pass     def learn_mnist(self):         mnist = input_data.read_data_sets('datasets',one_hot = True)         #tensorflow的input_data和read_data_sets,第一個是資料集存放路徑,第二個引數是標籤集的格式         #one_hot獨熱編碼,N個維度來對N個類別進行編碼,並且對於每個類別,只有一個維度有效,記作數字1 ;其它維度均記作數字0.         x_train_data = mnist.train.images  #取出輸入訓練訊號資料X_train,為矩陣形式,一行代表一個樣本有784維(28*28)         y_train_label = mnist.train.labels  #取出輸入訓練訊號標籤y_train,為one_hot為行的矩陣形式,每一行代表對應樣本的正確結果,在這是10維(0—9)         x_validation_data = mnist.validation.images #取出驗證集的資料集         y_validation_label = mnist.validation.labels #取出驗證集的資料集對應的標籤(10維0-9)         x_test_data = mnist.test.images    #取出測試集的資料集         y_test_label = mnist.test.labels    #取出測試集的資料集對應的標籤         print('x_train_data:{0} y_train_label:{1}'.format(x_train_data.shape,y_train_label.shape))         print(' x_validation_data:{0} y_validation_label:{1}'.format( x_validation_data.shape,y_validation_label.shape))         print('x_test_data:{0} y_test_label:{1}'.format(x_test_data.shape,y_test_label.shape))         image_raw = (x_train_data[1]*255).astype(int)#將第二個資料拿出,由0—1的浮點數,轉化為0-255的整數灰度值。         image = image_raw.reshape(28,28)    #將784維的行向量轉化為28*28的矩陣         label = y_train_label[1]            #讀取該樣本的正確結果標籤         idx = 0                             #定義索引         for item in label:             if 1==item:                 break   #標籤向量one-hot此元素為1時終止迴圈,             idx += 1         plt.title('digit:{0}'.format(idx))#將樣本標籤的正確結果顯示在圖片上         plt.imshow(image,cmap='gray')     #以灰度來顯示圖象         plt.show()                       #顯示圖片 def main(self):      mnister = Mnister()     mnister.learn_mnist() if '__main__' == __name__:     parser = argparse.ArgumentParser()     parser.add_argument('--datda_dir',type = str,default='datasets',                         help = 'Directory for storing input data')     FLAGS,unparsed = parser.parse_known_args()     tf.app.run(main = main,argv = [sys.argv[0]]+unparsed)