TensorFlow-cifar訓練與測試(可對自己資料進行分類和測試)
阿新 • • 發佈:2018-11-30
第一部分:測試軟硬體
硬體:NVIDIA-GTX1080
軟體:Windows7、python3.6.5、tensorflow-gpu-1.4.0
第二部分:資料下載
第三部分:程式碼分步展示
第一步:匯入所需套件(os、PIL、numpy、tensorflow)
import os
from PIL import Image
import numpy as np
import tensorflow as tf
第二步:定義模型函式
def model_load(x_, y_, keep_prob):
    """Build a small two-conv CNN classifier and return its logits.

    Args:
        x_: input image batch; this script feeds a float32 placeholder of
            shape [None, 32, 32, 3].
        y_: label placeholder. It is not used by the graph built here and is
            kept only so existing callers keep working.
        keep_prob: scalar float tensor fed at run time. Despite its name it is
            used as the dropout *drop rate*: this script feeds 0.25 while
            training and 0 while testing. Dropout is active only when the fed
            value is greater than 0.

    Returns:
        Logits tensor with 3 columns (one per class).
    """
    conv1 = tf.layers.conv2d(x_, 20, 5, activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
    conv2 = tf.layers.conv2d(pool1, 40, 4, activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(conv2, [2, 2], [2, 2])
    pool2_reshape = tf.layers.flatten(pool2)
    fc = tf.layers.dense(pool2_reshape, 400, activation=tf.nn.relu)
    # tf.layers.dropout takes the fraction to DROP and is a no-op unless
    # `training` is truthy. Its default is False, so the previous call never
    # applied dropout at all. Enable it whenever a positive rate is fed, so
    # the test-time feed of 0 still gives a deterministic forward pass.
    fc_dropout = tf.layers.dropout(
        fc, rate=keep_prob, training=tf.greater(keep_prob, 0.0))
    y_out = tf.layers.dense(fc_dropout, 3)
    return y_out
第三步:定義資料讀取函式
def read_data(data_dir):
    """Load every image in *data_dir* as a scaled array together with its label.

    Each file name must start with an integer class label followed by an
    underscore (e.g. ``0_something.png``). The label is parsed from that
    prefix.

    Args:
        data_dir: directory whose entries (non-recursive) are image files.

    Returns:
        A tuple ``(fpaths, datas, labels)``:
        - ``fpaths``: list of file paths.
        - ``datas``: numpy array of the images with pixel values divided
          by 255.
        - ``labels``: numpy integer array of labels.
        All three are in the same order (sorted by file name).
    """
    datas = []
    labels = []
    fpaths = []
    # os.listdir returns entries in arbitrary order; sort so the sample
    # order (and therefore the printed report) is reproducible.
    for fname in sorted(os.listdir(data_dir)):
        fpath = os.path.join(data_dir, fname)
        # Close each file handle promptly instead of leaving one open per
        # image until garbage collection.
        with Image.open(fpath) as image:
            data = np.array(image) / 255.0
        label = int(fname.split("_")[0])
        fpaths.append(fpath)
        datas.append(data)
        labels.append(label)
    return fpaths, np.array(datas), np.array(labels)
第四步:定義網路輸入輸出
# Network inputs: a batch of 32x32 images with 3 channels, and integer labels.
x_ = tf.placeholder(tf.float32, [None, 32, 32, 3])
y_ = tf.placeholder(tf.int32, [None])
# Scalar fed at run time and handed to model_load's dropout layer
# (this script feeds 0.25 while training and 0 while testing).
keep_prob = tf.placeholder(tf.float32)
# Build the model; y_out holds the unnormalized class logits.
y_out = model_load(x_, y_, keep_prob)
# Cross-entropy computed directly on integer labels. The sparse variant takes
# the number of classes from the logits, so it cannot drift out of sync with
# the model's output width the way a hard-coded tf.one_hot(y_, 3) depth can.
# It also rejects out-of-range labels instead of silently treating them as
# all-zero targets.
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_, logits=y_out)
mean_loss = tf.reduce_mean(losses)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)
# Top-1 prediction: index of the largest logit for each example.
predicted_labels = tf.argmax(y_out, 1)
第五步:讀取cifar資料集
# Image directory and checkpoint prefix, used by both training and testing.
data_dir, model_path = "data", "model/cifar_model"
# Load every image with its label; the paths are kept for the test-time report.
fpaths, datas, labels = read_data(data_dir)
第六步:訓練和測試(測試使用的即是訓練資料,因此100%準確率是訓練集上的準確率,並不代表泛化能力)
# Create the saver after the whole graph is built so it covers every variable.
saver = tf.train.Saver()
# True: train from scratch and save a checkpoint.
# False: restore the checkpoint and print predictions.
train = True
# The context manager closes the session when the script is done.
with tf.Session() as sess:
    if train:
        sess.run(tf.global_variables_initializer())
        for step in range(150):
            # Full-batch training step over every loaded image.
            _, mean_loss_val = sess.run(
                [optimizer, mean_loss],
                feed_dict={x_: datas, y_: labels, keep_prob: 0.25})
            if step % 10 == 0:
                print("step = {}\tmean loss = {}".format(step, mean_loss_val))
        # Save under exactly the prefix the test branch restores from.
        # Previously "150" was appended here, so restore(model_path) could
        # never find the checkpoint. Saver.save also needs the parent
        # directory to exist.
        os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
        saver.save(sess, model_path)
    else:
        labels_name = ["plane", "car", "bird"]
        saver.restore(sess, model_path)
        # Predictions do not depend on y_, so labels are not fed here.
        # NOTE: these are the same images used for training, so this reports
        # training accuracy, not generalization.
        predicted_labels_val = sess.run(
            predicted_labels, feed_dict={x_: datas, keep_prob: 0})
        for fpath, label, predicted in zip(fpaths, labels, predicted_labels_val):
            print("{}\t{} =? {}".format(
                fpath, labels_name[label], labels_name[predicted]))
第七步:備註
若發現loss不下降,這並不是過擬合(過擬合是訓練loss持續下降、但在新資料上表現變差)。較可能的原因是引數初始化不佳(例如ReLU單元在一開始就全部失活),或學習率過大(程式中Adam的學習率為1e-2)。這個沒關係,刪除model資料夾後重新訓練一遍即可;若經常發生,可將學習率調小(例如1e-3)。
任何問題請加唯一QQ2258205918(名稱samylee)!