1. 程式人生 > >卷積神經網路簡單的應用(三):模型測試

卷積神經網路簡單的應用(三):模型測試

  1. 模型測試
    模型訓練好之後通過重新載入模型的方式進行模型測試,使用Tensorflow中的Saver物件。相關程式碼如下:
    def test_cnn(x_data):
        output = create_cnn(4)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            #載入訓練好的模型
            saver.restore(sess,"./model/cnn.model-2100")
            preject = tf.argmax(output,1)
            x_in = np.array(x_data)
            #keep_prob設定為1
            label = sess.run(preject,feed_dict={X:[x_in],keep_prob:1})
        return label
    主函式為:
    if __name__ == '__main__':
        isTrain = 2    
        if 1 == isTrain:        
            X = tf.placeholder(tf.float32,[None,200,150,3])
            Y = tf.placeholder(tf.float32,[None,4])
            
            keep_prob = tf.placeholder(tf.float32)
            train_cnn(xdata,ydata)
        if 2 == isTrain:
            #將測試資料放在相應的檔案中    
            path_list = ['./0','./1','./2','./3']
            for p in path_list:
                file_info = os.listdir(p)
                for file_name in file_info:
                    x_data = read_test_data(p+'/'+file_name)
                    if type(x_data) == type(None):
                        print('==>',p)
                        continue
                    #沒有這句,會出現問題
                    tf.reset_default_graph()  
                    X = tf.placeholder(tf.float32,[None,200,150,3])
                    keep_prob = tf.placeholder(tf.float32)
                    l = test_cnn(x_data)
                    label = ['girl','beauty girl','boy','handsome boy']               
                    print(p,':',file_name,'====>',label[l[0]])