卷積神經網路簡單的應用(三):模型測試
阿新 • • 發佈:2019-01-04
- 模型測試
模型訓練好之後通過重新載入模型的方式進行模型測試,使用Tensorflow中的Saver物件。相關程式碼如下:
主函式為:def test_cnn(x_data): output = create_cnn(4) saver = tf.train.Saver() with tf.Session() as sess: #載入訓練好的模型 saver.restore(sess,"./model/cnn.model-2100") preject = tf.argmax(output,1) x_in = np.array(x_data) #keep_prob設定為1 label = sess.run(preject,feed_dict={X:[x_in],keep_prob:1}) return label
if __name__ == '__main__': isTrain = 2 if 1 == isTrain: X = tf.placeholder(tf.float32,[None,200,150,3]) Y = tf.placeholder(tf.float32,[None,4]) keep_prob = tf.placeholder(tf.float32) train_cnn(xdata,ydata) if 2 == isTrain: #將測試資料放在相應的檔案中 path_list = ['./0','./1','./2','./3'] for p in path_list: file_info = os.listdir(p) for file_name in file_info: x_data = read_test_data(p+'/'+file_name) if type(x_data) == type(None): print('==>',p) continue #沒有這句,會出現問題 tf.reset_default_graph() X = tf.placeholder(tf.float32,[None,200,150,3]) keep_prob = tf.placeholder(tf.float32) l = test_cnn(x_data) label = ['girl','beauty girl','boy','handsome boy'] print(p,':',file_name,'====>',label[l[0]])