1. 程式人生 > >在OpenCV中實現決策樹和隨機森林

在OpenCV中實現決策樹和隨機森林

目錄

1.決策樹

2.隨機森林


1.決策樹

需要注意的點

     Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data",//檔名
		                                                    0,//第0行略過
		                                                    0,
		                                            1,//區間[0,1)為儲存了響應列
		                                            "cat[0-22]",//0-22行均為類別資料
		                                            ',',//資料間的分隔符號為","
	                                                     '?');//丟失資料用"?"表示

1.資料型別有cat和ord之分,具體可以參閱統計資料定義:

  https://zhidao.baidu.com/question/1964314134743418500.html

2.預設的響應列的格式(第2-第3行)是前閉後開;

3.分割訓練集和資料集時,資料集的順序會大幅度的影響決策樹的結果:

	data_set->setTrainTestSplitRatio(0.90, false);

4.對於概率權重的設定,你可以理解為對識別某一類物體具有相對更高的準確率(請注意我的矩陣初始化方法);

	float _priors[] = { 1.0,10.0 };
	Mat priors(1, 2, CV_32F, _priors);
	dtree->setPriors(priors);//為所有的答案設定權重

5.在OpenCV3.0以上的版本使用決策樹與隨機森林所繼承的類都是RTrees,相比與DTrees而言,新的類RTrees能夠處理資料集中的缺失資料,建模的唯一區別就是在生成隨機森林時,需要設定樹的終止生成條件,預設是100棵樹:

 forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER 
                               + TermCriteria::EPS, 100, 0.01));//隨機森林的終止標準

6.模型的儲存與載入:

	dtree->save("dtree_01.xml");//儲存
        Ptr<RTrees> dtree = RTrees::load("dtree_01.xml");//載入訓練模型

7.如果你想用載入後的模型進行資料的分類和迴歸,請務必手動建立訓練集、測試集和驗證集;這是因為如果你的資料集裡面含有字母類別,那麼opencv會以預設的方式轉化為ASCALL碼並歸一化,在這種情況下如果依舊使用預設的方式載入驗證集,必然會在使用predict時程式崩潰。

決策樹的訓練程式碼:

毒蘑菇的資料集:https://github.com/oreillymedia/Learning-OpenCV-3_examples/tree/master/mushroom

#include<iostream>
#include<opencv2/opencv.hpp>

using namespace cv;
using namespace ml;
using namespace std;

//1.生成訓練集結構體物件指標
	Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data",//檔名
		                                           0,//第0行略過
		                                           0,
		                                           1,//區間[0,1)為儲存了響應列
		                                 "cat[0-22]",//0-22行均為類別資料
		                                         ',',//資料間的分隔符號為","
	                                                '?');//丟失資料用"?"表示

//2.驗證資料讀取的正確性
	int n_samples = data_set->getNSamples();
	if (n_samples == 0)
	{
		cerr << "Could not read file: mushroom.data" << endl;
		exit(-1);
	}
	else
	{
		cout << "Read " << n_samples << " samples from mushroom.data" << endl;
	}

	//3.分割訓練集和測試集,比例為9:1,不打亂資料集的順序
	data_set->setTrainTestSplitRatio(0.90, false);
	int n_train_samples = data_set->getNTrainSamples();
	int n_test_samples = data_set->getNTestSamples();
	Mat trainMat = data_set->getTrainSamples();
	//4.決策樹
	//4.1 建立
	Ptr<RTrees> dtree = RTrees::create();
	//4.2 引數設定
	dtree->setMaxDepth(8); //樹的最大深度
	dtree->setMinSampleCount(10); //節點樣本數的最小值
	dtree->setRegressionAccuracy(0.01f);
	dtree->setUseSurrogates(false);//是否允許使用替代分叉點處理丟失的資料
	dtree->setMaxCategories(15);//決策樹的最大預分類數量
	dtree->setCVFolds(0);//如果 CVFolds>1 那麼就使用k-fold交叉修建決策樹 其中k=CVFolds
	dtree->setUse1SERule(true);//True 表示使用更大力度的修剪,這會導致樹的規模更小,
                                                           但準確性更差,用於解決過擬合問題
	dtree->setTruncatePrunedTree(true);//是否刪掉被減枝的部分
	float _priors[] = { 1.0,10.0 };
	Mat priors(1, 2, CV_32F, _priors);
	dtree->setPriors(priors);//為所有的答案設定權重
	//4.3 訓練
	dtree->train(data_set);
	//4.4 計算訓練誤差
	Mat results;
	float train_performance = dtree->calcError(data_set,
		false,//true 表示使用測試集  false 表示使用訓練集
		results);
	//5 訓練集的結果分析
	vector<String> names;
	data_set->getNames(names);
	Mat flags = data_set->getVarSymbolFlags();
	Mat expected_responses = data_set->getResponses();
	int good = 0, bad = 0, total = 0;
	for (int i = 0; i < data_set->getNTrainSamples(); ++i)
	{
		float received = results.at<float>(i, 0);
		float expected = expected_responses.at<float>(i, 0);
		String r_str = names[(int)received];
		String e_str = names[(int)expected];
		if (received != expected)
		{
		   bad++;
                   cout << "Expected: " << e_str << " ,got: " << r_str << endl;
		}
		else good++;
		total++;
	}
	cout << "Correct answers: " << (float(good) / total) << "% " << endl;
	cout << "Incorrect answers: " << (float(bad) / total) << "% " << endl;

	//6 測試集的結果分析
	float test_performance = dtree->calcError(data_set, true, results);
	cout << "Performance on training data: " << train_performance << "%" << endl;
	cout << "Performance on test data: " << test_performance << "%" << endl;

	//儲存
	dtree->save("dtree_01.xml");

資料的預測:

Ptr<RTrees> dtree = RTrees::load("dtree_01.xml");
Mat sample = (Mat_<float>(1, 22)
          << 2, 3, 4, 5, 6, 7, 8, 2, 9, 1, 8, 10, 10, 4, 4, 11, 4, 12, 11, 9, 10, 13);
float result = dtree->predict(sample);

2.隨機森林

使用隨機森林需要注意的點:

1.隨機森林的引數設定比較簡單,不要要考慮數的減枝等因素,但計算變數的重要性會需要額外的計算時間:

	 Ptr<RTrees> forest_mushroom = RTrees::create();
	 forest_mushroom->setMaxDepth(10); //樹的最大深度
	 forest_mushroom->setRegressionAccuracy(0.01f);//設定迴歸精度
	 forest_mushroom->setMinSampleCount(10);//節點的最小樣本數量
	 forest_mushroom->setMaxCategories(15);//最大預分類數
	 forest_mushroom->setCalculateVarImportance(true);//計算變數的重要性
	 forest_mushroom->setActiveVarCount(4);//樹節點隨機選擇特徵子集的大小
                                               //終止標準
	 forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER + 
                                                    TermCriteria::EPS, 100, 0.01));

2.在資料的預測部分 與決策樹有些微的不同,隨機森林的預測有兩種形式,第一種是直接基於投票結果給出響應值:


		 Mat sample = testSample.row(i);
		 float r = forest_mushroom->predict(sample);
		 r = fabs((float)r - testResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0;
		

第二種是使用getVotes計算票選矩陣(注意:所計算出來的票選矩陣已經提取了最大公約數)我們可以利用票選結果來計算概率;

                 Mat sample = testSample.row(i);
		 Mat result;
		 forest_mushroom->getVotes(sample,result,0);

3.隨機森林的泛化能力很強,只需要較少的樣本就可以實現很高精度的分類,比如在8400多個蘑菇樣本中,我只使用了840樣本生成隨機森林,最後的分類準確率依然高達96%

4.載入、讀取以及驗證集的使用請參閱隨機森林部分

建模程式碼:

    //建立蘑菇分類的隨機森林
     //1.構建訓練集和測試集
     Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data",//檔名
	                                               0,//第0行略過
	                                               0,
	                                               1,//區間[0,1)為儲存了響應列
	                                     "cat[0-22]",//0-22行均為類別資料
	                                             ',',//資料間的分隔符號為","
	                                            '?');//丟失資料用"?"表示
     //2.驗證資料讀取的正確性
	 int n_samples = data_set->getNSamples();
	 if (n_samples == 0)
	 {
		 cerr << "Could not read file: mushroom.data" << endl;
		 exit(-1);
	 }
	 else
	 {
		 cout << "Read " << n_samples << " samples from mushroom.data" << endl;
	 }

	 //3.分割訓練集和測試集,比例為9:1,打亂資料集的順序
	 data_set->setTrainTestSplitRatio(0.90, true);
	 int n_train_samples = data_set->getNTrainSamples();
	 int n_test_samples = data_set->getNTestSamples();

	 //4.隨機森林
	 Ptr<RTrees> forest_mushroom = RTrees::create();
	 forest_mushroom->setMaxDepth(10); //樹的最大深度
	 forest_mushroom->setRegressionAccuracy(0.01f);//設定迴歸精度
	 forest_mushroom->setMinSampleCount(10);//節點的最小樣本數量
	 forest_mushroom->setMaxCategories(15);//最大預分類數
	 forest_mushroom->setCalculateVarImportance(true);//計算變數的重要性
	 forest_mushroom->setActiveVarCount(4);//樹節點隨機選擇特徵子集的大小
	 forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER +  
                                            TermCriteria::EPS, 100, 0.01));//終止標準
	 //訓練模型
	 forest_mushroom->train(data_set);
	 //計算訓練集和測試集的誤差
	 float correct_Train_answer = 0;
	 float correct_Test_answer = 0;
	 //1.訓練集
	 Mat trainSample = data_set->getTrainSamples();
	 Mat trainResMat = data_set->getTrainResponses();
	 for (int i = 0; i < trainSample.rows; i++)
	 {
		 Mat sample = trainSample.row(i);
		 float r = forest_mushroom->predict(sample);
		 r = fabs((float)r - trainResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0;
		 correct_Train_answer += r;
	 }
	 float r1 = correct_Train_answer / n_train_samples;
	 //2.測試集
	 Mat testSample =  data_set->getTestSamples();
	 Mat testResMat = data_set->getTestResponses();
	 for (int i = 0; i < testSample.rows; i++)
	 {
		 Mat sample = testSample.row(i);
		 float r = forest_mushroom->predict(sample);
		 r = fabs((float)r - testResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0;
		 correct_Test_answer += r;
	 }
	 float r2 = correct_Test_answer / n_test_samples;
	 //3.輸出結果
	 cout << "trainSet Accuracy: " << r1* 100 << "%" << endl;
	 cout << "testSet Accuracy:  " << r2 * 100 << "%" << endl;
	 //4.儲存模型
	 forest_mushroom->save("forest_mushroom.xml");

 

相關資料

1.OpenCV 3.4 官方ML庫手冊:

https://docs.opencv.org/3.4.0/dd/ded/group__ml.html

2.隨機森林:

https://blog.csdn.net/akadiao/article/details/79413713

https://www.cnblogs.com/hrlnw/p/3850459.html

https://blog.csdn.net/akadiao/article/details/79413713

https://blog.csdn.net/wishchin/article/details/78662797

https://blog.csdn.net/wishchin/article/details/78662797

3.OpenCV作者Gray的Github關於毒蘑菇的資料集資料:

https://github.com/oreillymedia/Learning-OpenCV-3_examples/tree/master/mushroom

4.統計資料的定義:

https://zhidao.baidu.com/question/1964314134743418500.html

5.OpenCV機器學習部落格:

https://www.cnblogs.com/denny402/p/5032232.html

6.ASCALL碼對照表

https://blog.csdn.net/u011930916/article/details/79623922

7.OpenCV中Mat物件使用全解:

https://blog.csdn.net/guyuealian/article/details/70159660

8.OpenCV資料型別的位數總結:

https://blog.csdn.net/lcgwust/article/details/70770148