1. 程式人生 > >C++呼叫tensorflow 訓練好的模型

C++呼叫tensorflow 訓練好的模型

這個東西我弄了好久!!!!

中間都想放棄了,但是我不服啊,還好弄出來了!!!分享給大家,希望可以幫助到大家哦~嘻嘻嘻

我看到有些說只能安裝32位的py,我開始也是這樣的,但是安裝TensorFlow做測試的時候,就一直有問題,所以呀,我就換成了ananconda 

安裝這個網上一大堆,自己可以好好看看哦!

然後就是新建一個C++的工程

1.把ananconda的減壓後,將裡面的inlcude和libs兩個資料夾拷貝到sln的同一級目錄下

2.然後開啟libs,複製一份python35.lib,並命名為python35_d.lib

3.C++->常規->附加包含目錄,輸入..\include;

4.連結器->常規->附加目錄項,輸入..\libs;

5.連結器->輸入->附加依賴項,新增python35_d.lib;

6. python35.dll拷貝到Debug目錄下(與Test.exe同目錄)

7.將py拷貝到Debug目錄下(與Test.exe同目錄)

8.將你訓練好的模型新建一個資料夾拷貝到C++專案資料夾裡來

1-----測試圖片   2.3就是py裡面的東西  4就是你的模型

好了  我開始貼程式碼了~~~

C++:

void testImage(char * path)
{
	try{
		Py_Initialize();
		PyEval_InitThreads();
		PyObject*pFunc = NULL;
		PyObject*pArg = NULL;
		PyObject* module = NULL;
		module = PyImport_ImportModule("myModel");//myModel:Python檔名
		if (!module) {
			printf("cannot open module!");
			//Py_Finalize();
		}
		pFunc = PyObject_GetAttrString(module, "test_one_image");//test_one_image:Python檔案中的函式名
		if (!pFunc) {
			printf("cannot open FUNC!");
			//Py_Finalize();
		}
		//開始呼叫model
		pArg = Py_BuildValue("(s)", path);
		if (module != NULL) {
			PyGILState_STATE gstate;
		        gstate = PyGILState_Ensure();
			 PyEval_CallObject(pFunc, pArg);
			PyGILState_Release(gstate);
	
		}
	}
	 catch (exception& e)
	 {
		 cout << "Standard exception: " << e.what() << endl;
	 }
}

python:
def test_one_image(test_dir):
    image = Image.open(test_dir)
    plt.imshow(image)
    plt.axis('off')
    plt.show()
    image = image.resize([32, 32])
    image_array = np.array(image)

    with tf.Graph().as_default():
        image = tf.cast(image_array, tf.float32)
        image = tf.reshape(image, [1, 32, 32, 3])#調整image的形狀
        p = mmodel(image, 1)
        logits = tf.nn.softmax(p)
        x = tf.placeholder(tf.float32, shape=[32, 32, 3])
        saver = tf.train.Saver()
        model_path='E:/MyProject/MachineLearning/call64PY/test/model/'
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, tf.train.latest_checkpoint('E:/MyProject/MachineLearning/call64PY/test/model/'))
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('載入ckpt成功!')
            else:
                print('error')

            prediction = sess.run(logits, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print('case0: %.6f' % prediction[:, 0])
                return result
            else:
                print('-case1: %.6f' % prediction[:, 1])
                return result2

這裡面好多坑啊~~~