1. 程式人生 > >非常精簡的Mnist分類,基於tensorflow框架

非常精簡的Mnist分類,基於tensorflow框架

一、介紹

    基於tensorflow框架實現的Mnist資料分類。程式碼主要包括網路結構的搭建,訓練超引數的匯入和儲存,損失函式地繪製等。不足之處是在網路結尾沒用使用softmax函式,而直接使用了tanh輸出了分類結果。下面請看程式碼的詳細介紹

二、程式碼

  • 匯入必要的包檔案,需要的包我直接通過pycharm匯入的,能匯入的原因是採用了anaconda3底下的python.exe,新建工程的時候,從外部匯入
# 需要使用到的包檔案
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import argparse
import os
# 加上這一句能夠使Plot繪製出來的圖更精美
sns.set_style("whitegrid")
  • 訓練引數設定,詳細介紹請看程式碼註釋,主要採用了argparse,該模組的好處是直接可以在執行時修改引數,比如:python main.py --data_dir= "**"
parser = argparse.ArgumentParser(description="Network for image classification")                
parser.add_argument('--data_dir', default='tem/data', help='Directory for training data')   # Mnist資料集存放位置
parser.add_argument('--result_dir', default='tem/result')                                   # 訓練結果的存放
parser.add_argument('--model_dir', default='model/', help='the place of saving networks parameters')   #訓練引數的存放地址
parser.add_argument('--batch_size', default=32)
parser.add_argument('--print_loss', default=10) # 每隔10次迭代列印損失值
parser.add_argument('--plot_loss', default=100) # 每隔100次迭代繪製損失函式曲線
parser.add_argument('--learning_rate', default=0.001, type=float) # 學習率,不易設定過大
parser.add_argument('--n_iterations', default=10000, type=int) # 迭代次數
args = parser.parse_args() # 將--*的*傳遞給arg,呼叫時直接使用args.data_dir這樣的結構
  • 網路結構搭建
w_init = tf.random_normal_initializer(stddev=0.01)   # 權重w初始化,標準差為0.01,平均值0
def network(x): # 啟用函式都為relu,除了輸出
    layers1 = tf.layers.conv2d(x, 32, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init)      # 32個卷積核,3x3卷積核大小,步長為1,padding為'same',即輸出大小為input/stride,向上取整
    layers2 = tf.layers.conv2d(layers1, 62, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init) 
    layers2_flatten = tf.contrib.layers.flatten(layers2)  # 將layers2的輸出"磨平",降低相關維度,以供全連線層工作
    layers3 = tf.layers.dense(layers2_flatten, 200, activation=tf.nn.relu, kernel_initializer=w_init) # 200為全連線層單元個數,其它的痛卷積函式類似
    output = tf.layers.dense(layers3, 10, activation=tf.nn.tanh, kernel_initializer=w_init) # 使用tanh作為輸出,比sigmoid好,因為sigmoid是有0項,不利於網路訓練
    return output
  • 訓練網路,詳細介紹看註釋
def training():
    input_x = tf.placeholder(tf.float32, [None, 28, 28, 1])  # 放置佔位矩陣
    label_y = tf.placeholder(tf.float32, [None, 10])
    output_y = network(input_x)                              # 前向傳播
    loss = tf.reduce_sum(tf.square(label_y-output_y))        # 計算同便籤損失
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate).minimize(loss) # 使用Adam優化

    init_all_v = tf.global_variables_initializer()           # 張量初始化函式
    sess = tf.InteractiveSession()                           
    sess.run(init_all_v)                                     # 實行張量初始化
    saver = load_model(sess)                                 # 匯入之前訓練過的引數,如果沒有則打印出來
    mnist = read_data_sets(args.data_dir, one_hot=True)      # 往指定目錄讀取Mnist資料集
    print('start training')                                   
    plot_loss = []                                           # 損失值快取
    for i in range(args.n_iterations):                             
        batch_x, batch_y = mnist.train.next_batch(args.batch_size)      # 讀取Batch_size
        batch_x = batch_x.reshape([args.batch_size, 28, 28, 1])         # 維度匹配   
        y = np.zeros([args.batch_size, 10])                             # 下面的操作是因為我讀到的標籤是6,8,9直接對應的圖片的數字,所以將這些數字向量化,以便訓練
        for j in range(args.batch_size):
            k = batch_y[j].astype(np.int)
            y[j, k] = 1.
        batch_y = y
        d_loss, _ = sess.run([loss, optimizer], feed_dict={input_x:batch_x, label_y:batch_y})  # 執行
        plot_loss.append(d_loss)

        if i % args.print_loss == 0 and i > 0:
            print('Iteration is : %d, Loss is: %f' % (i, d_loss))   # 列印損失
        if i % args.plot_loss == 0 and i > 0:            # 繪圖
            plt.figure(figsize=(6*1.1618, 6))
            plt.plot(range(len(plot_loss)), plot_loss)
            plt.xlabel('iteration times')
            plt.ylabel('lost')
            plt.show()
        if i % 500 == 0 and i > 0:
            save_model(saver, sess, i)
  • 模組的匯入與儲存
def save_model(saver, sess, step):   # 儲存模組
    saver.save(sess, os.path.join(args.model_dir, "classification"), global_step=step)
def load_model(sess):                # 匯入模組 
    saver = tf.train.Saver()
    checkpoint = tf.train.get_checkpoint_state(args.model_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print("Successfully loaded:", checkpoint.model_checkpoint_path)
    else:
        print("Could not find any old weights!")
    return saver
  • 主函式
def main(_):
    training()
if __name__ == "__main__":
    tf.app.run()

 

從上往下黏貼就行,貼到IDE下就可以執行,還可以列印損失函式

 

鬼知道為什麼下降這麼快,,,