Implementing Decision Trees and Random Forests in OpenCV
1. Decision Trees
Points to note:
Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data", // file name
    0,           // number of header lines to skip
    0, 1,        // columns in the half-open range [0, 1) hold the response
    "cat[0-22]", // columns 0-22 are all categorical
    ',',         // fields are separated by ","
    '?');        // missing values are marked with "?"
1. Variable types come in two kinds, cat (categorical) and ord (ordered/numeric); see the definition of statistical data types, and the illustrative spec below:
https://zhidao.baidu.com/question/1964314134743418500.html
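For illustration, a hypothetical varTypeSpec mixing the two kinds might look like the call below (the mushroom data is in fact all categorical; the file name and column ranges here are made up):
// Hypothetical: columns 0-10 ordered (numeric), columns 11-22 categorical
Ptr<TrainData> mixed = TrainData::loadFromCSV("some_other.data", 0, 0, 1,
    "ord[0-10],cat[11-22]", ',', '?');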
2. The response-column range (the second and third index arguments, 0 and 1 above) is half-open, so [0, 1) selects only column 0;
3. When splitting into training and test sets, the order of the data strongly affects the resulting decision tree:
data_set->setTrainTestSplitRatio(0.90, false); // 90/10 split, no shuffling
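If a randomized split is wanted instead (as the random forest example later in this post uses), pass true for the shuffle flag:
data_set->setTrainTestSplitRatio(0.90, true); // shuffle before splitting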
4. Setting prior weights can be understood as asking for relatively higher accuracy on one particular class (note the matrix initialization method used here):
float _priors[] = { 1.0f, 10.0f };
Mat priors(1, 2, CV_32F, _priors);
dtree->setPriors(priors); // weight the classes
5. In OpenCV 3.0 and later, both decision trees and random forests are built through the RTrees class. Compared with DTrees, the newer RTrees can cope with missing values in the data set. The only difference when modeling a random forest is that you must set the termination criteria for growing trees; the default is 100 trees:
forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER
    + TermCriteria::EPS, 100, 0.01)); // termination criteria for the forest
6. Saving and loading the model:
dtree->save("dtree_01.xml");                      // save
Ptr<RTrees> dtree = RTrees::load("dtree_01.xml"); // load the trained model
7. If you want to classify or regress with a reloaded model, be sure to build the training, test, and validation sets manually. If the data set contains letter-valued categories, OpenCV by default converts them to ASCII codes and normalizes them, so feeding a validation set loaded the default way to predict will inevitably crash the program.
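One safe option, sketched here, is to take validation samples from the numeric matrices TrainData has already encoded instead of re-parsing raw CSV rows (getTestSamples/getTestResponses are standard TrainData accessors):
// Reuse the numeric encoding TrainData already produced
Mat testSamples = data_set->getTestSamples();
Mat testResponses = data_set->getTestResponses();
float pred = dtree->predict(testSamples.row(0));
cout << "predicted " << pred << ", expected " << testResponses.at<float>(0, 0) << endl;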
Decision tree training code:
Poisonous-mushroom data set: https://github.com/oreillymedia/Learning-OpenCV-3_examples/tree/master/mushroom
#include<iostream>
#include<opencv2/opencv.hpp>
using namespace cv;
using namespace cv::ml;
using namespace std;
int main()
{
//1. Create a pointer to the TrainData object
Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data", // file name
    0,           // number of header lines to skip
    0,           // response columns start at index 0...
    1,           // ...and span the half-open range [0, 1)
    "cat[0-22]", // columns 0-22 are all categorical
    ',',         // fields are separated by ","
    '?');        // missing values are marked with "?"
//2. Verify that the data was read correctly
int n_samples = data_set->getNSamples();
if (n_samples == 0)
{
cerr << "Could not read file: mushroom.data" << endl;
exit(-1);
}
else
{
cout << "Read " << n_samples << " samples from mushroom.data" << endl;
}
//3. Split into training and test sets, 9:1, without shuffling the data
data_set->setTrainTestSplitRatio(0.90, false);
int n_train_samples = data_set->getNTrainSamples();
int n_test_samples = data_set->getNTestSamples();
Mat trainMat = data_set->getTrainSamples();
//4. Decision tree
//4.1 Create
Ptr<RTrees> dtree = RTrees::create();
//4.2 Parameter settings
dtree->setMaxDepth(8);               // maximum tree depth
dtree->setMinSampleCount(10);        // minimum number of samples at a node
dtree->setRegressionAccuracy(0.01f); // regression accuracy (not used for classification)
dtree->setUseSurrogates(false);      // whether surrogate splits may handle missing data
dtree->setMaxCategories(15);         // maximum number of pre-clustered categories
dtree->setCVFolds(0);                // if CVFolds > 1, prune with k-fold cross-validation, k = CVFolds
dtree->setUse1SERule(true);          // true prunes more aggressively: a smaller tree,
                                     // slightly lower accuracy; helps against overfitting
dtree->setTruncatePrunedTree(true);  // whether pruned branches are physically removed
float _priors[] = { 1.0f, 10.0f };
Mat priors(1, 2, CV_32F, _priors);
dtree->setPriors(priors);            // weight the classes
//4.3 Train
dtree->train(data_set);
//4.4 Compute the training error
Mat results;
float train_performance = dtree->calcError(data_set,
    false, // true = use the test set, false = use the training set
    results);
//5. Analyze results on the training set
vector<String> names;
data_set->getNames(names);
Mat flags = data_set->getVarSymbolFlags();
Mat expected_responses = data_set->getResponses();
int good = 0, bad = 0, total = 0;
for (int i = 0; i < data_set->getNTrainSamples(); ++i)
{
float received = results.at<float>(i, 0);
float expected = expected_responses.at<float>(i, 0);
String r_str = names[(int)received];
String e_str = names[(int)expected];
if (received != expected)
{
bad++;
cout << "Expected: " << e_str << " ,got: " << r_str << endl;
}
else good++;
total++;
}
cout << "Correct answers: " << (float(good) / total) << "% " << endl;
cout << "Incorrect answers: " << (float(bad) / total) << "% " << endl;
//6. Analyze results on the test set
float test_performance = dtree->calcError(data_set, true, results);
cout << "Performance on training data: " << train_performance << "%" << endl;
cout << "Performance on test data: " << test_performance << "%" << endl;
//Save the model
dtree->save("dtree_01.xml");
return 0;
}
Predicting with the saved model:
Ptr<RTrees> dtree = RTrees::load("dtree_01.xml");
Mat sample = (Mat_<float>(1, 22)
<< 2, 3, 4, 5, 6, 7, 8, 2, 9, 1, 8, 10, 10, 4, 4, 11, 4, 12, 11, 9, 10, 13);
float result = dtree->predict(sample);
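If the TrainData object from training is still in scope, the numeric result can be mapped back to its symbolic label through getNames(), just as the training code above does; a minimal sketch:
// Assumes 'data_set' from the training step is still available
vector<String> names;
data_set->getNames(names);
cout << "Predicted class: " << names[(int)result] << endl;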
2. Random Forests
Points to note when using random forests:
1. Random forest parameters are comparatively simple to set; there is no need to consider tree pruning and similar factors, but computing variable importance requires extra computation time (a sketch of reading it back appears after the block below):
Ptr<RTrees> forest_mushroom = RTrees::create();
forest_mushroom->setMaxDepth(10);                 // maximum tree depth
forest_mushroom->setRegressionAccuracy(0.01f);    // regression accuracy
forest_mushroom->setMinSampleCount(10);           // minimum number of samples at a node
forest_mushroom->setMaxCategories(15);            // maximum number of pre-clustered categories
forest_mushroom->setCalculateVarImportance(true); // compute variable importance
forest_mushroom->setActiveVarCount(4);            // size of the random feature subset tried at each node
// termination criteria
forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER +
    TermCriteria::EPS, 100, 0.01));
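Once train() has run, the importance values enabled by setCalculateVarImportance(true) can be read back with the standard getVarImportance() accessor; a minimal sketch:
// Per-feature importance; one float entry per input variable
Mat varImportance = forest_mushroom->getVarImportance();
for (int i = 0; i < (int)varImportance.total(); i++)
    cout << "feature " << i << " importance: " << varImportance.at<float>(i) << endl;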
2. Prediction differs slightly from the decision tree: a random forest can predict in two ways. The first directly returns the response value decided by the vote:
Mat sample = testSample.row(i);
float r = forest_mushroom->predict(sample);
r = fabs((float)r - testResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0; // 1 if correct, 0 otherwise
The second uses getVotes to compute the vote matrix (note: the returned vote counts have already been divided by their greatest common divisor); the votes can then be used to estimate probabilities, as sketched after the snippet below:
Mat sample = testSample.row(i);
Mat result;
forest_mushroom->getVotes(sample,result,0);
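A minimal sketch of turning the votes into rough class probabilities, assuming the OpenCV 3.4 getVotes layout (a CV_32S matrix whose first row holds the class labels and whose second row holds the vote counts for this single sample):
int totalVotes = 0;
for (int c = 0; c < result.cols; c++)
    totalVotes += result.at<int>(1, c);
for (int c = 0; c < result.cols; c++)
    cout << "class " << result.at<int>(0, c) << ": "
        << (float)result.at<int>(1, c) / totalVotes << endl;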
3. Random forests generalize very well; only a small number of samples is needed for highly accurate classification. For example, of the 8,400-odd mushroom samples I used only 840 to grow the forest, and the final classification accuracy was still as high as 96%.
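A sketch of such a split, not necessarily the author's exact code: train on roughly 10% of the samples and test on the rest:
data_set->setTrainTestSplitRatio(0.10, true); // keep ~10% (about 840 of ~8400 samples) for training, shuffled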
4. For saving and loading the model, and for using a validation set, see the decision tree section above.
Modeling code:
#include<iostream>
#include<opencv2/opencv.hpp>
using namespace cv;
using namespace cv::ml;
using namespace std;
//Build a random forest for mushroom classification
int main()
{
//1. Build the training and test sets
Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data", // file name
    0,           // number of header lines to skip
    0,           // response columns start at index 0...
    1,           // ...and span the half-open range [0, 1)
    "cat[0-22]", // columns 0-22 are all categorical
    ',',         // fields are separated by ","
    '?');        // missing values are marked with "?"
//2. Verify that the data was read correctly
int n_samples = data_set->getNSamples();
if (n_samples == 0)
{
cerr << "Could not read file: mushroom.data" << endl;
exit(-1);
}
else
{
cout << "Read " << n_samples << " samples from mushroom.data" << endl;
}
//3. Split into training and test sets, 9:1, shuffling the data
data_set->setTrainTestSplitRatio(0.90, true);
int n_train_samples = data_set->getNTrainSamples();
int n_test_samples = data_set->getNTestSamples();
//4. Random forest
Ptr<RTrees> forest_mushroom = RTrees::create();
forest_mushroom->setMaxDepth(10);                 // maximum tree depth
forest_mushroom->setRegressionAccuracy(0.01f);    // regression accuracy
forest_mushroom->setMinSampleCount(10);           // minimum number of samples at a node
forest_mushroom->setMaxCategories(15);            // maximum number of pre-clustered categories
forest_mushroom->setCalculateVarImportance(true); // compute variable importance
forest_mushroom->setActiveVarCount(4);            // size of the random feature subset tried at each node
forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER +
    TermCriteria::EPS, 100, 0.01));               // termination criteria
//Train the model
forest_mushroom->train(data_set);
//Compute accuracy on the training and test sets
float correct_Train_answer = 0;
float correct_Test_answer = 0;
//1. Training set
Mat trainSample = data_set->getTrainSamples();
Mat trainResMat = data_set->getTrainResponses();
for (int i = 0; i < trainSample.rows; i++)
{
Mat sample = trainSample.row(i);
float r = forest_mushroom->predict(sample);
r = fabs((float)r - trainResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0; // 1 if correct, 0 otherwise
correct_Train_answer += r;
}
float r1 = correct_Train_answer / n_train_samples;
//2. Test set
Mat testSample = data_set->getTestSamples();
Mat testResMat = data_set->getTestResponses();
for (int i = 0; i < testSample.rows; i++)
{
Mat sample = testSample.row(i);
float r = forest_mushroom->predict(sample);
r = fabs((float)r - testResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0; // 1 if correct, 0 otherwise
correct_Test_answer += r;
}
float r2 = correct_Test_answer / n_test_samples;
//3. Print the results
cout << "trainSet Accuracy: " << r1 * 100 << "%" << endl;
cout << "testSet Accuracy: " << r2 * 100 << "%" << endl;
//4. Save the model
forest_mushroom->save("forest_mushroom.xml");
return 0;
}
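As with the decision tree, the saved forest can be reloaded for later prediction; a minimal sketch, reusing the encoded sample row from the decision tree example above:
Ptr<RTrees> forest = RTrees::load("forest_mushroom.xml");
Mat sample = (Mat_<float>(1, 22)
    << 2, 3, 4, 5, 6, 7, 8, 2, 9, 1, 8, 10, 10, 4, 4, 11, 4, 12, 11, 9, 10, 13);
float pred = forest->predict(sample);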
References
1. OpenCV 3.4 official ML module documentation:
https://docs.opencv.org/3.4.0/dd/ded/group__ml.html
2. Random forests:
https://blog.csdn.net/akadiao/article/details/79413713
https://www.cnblogs.com/hrlnw/p/3850459.html
https://blog.csdn.net/wishchin/article/details/78662797
3. The mushroom data set from OpenCV author Gary Bradski's GitHub (Learning OpenCV 3 examples):
https://github.com/oreillymedia/Learning-OpenCV-3_examples/tree/master/mushroom
4. Definition of statistical data types:
https://zhidao.baidu.com/question/1964314134743418500.html
5. OpenCV machine learning blog:
https://www.cnblogs.com/denny402/p/5032232.html
6. ASCII code reference table:
https://blog.csdn.net/u011930916/article/details/79623922
7. Complete guide to OpenCV's Mat object:
https://blog.csdn.net/guyuealian/article/details/70159660
8. Summary of OpenCV data-type bit depths:
https://blog.csdn.net/lcgwust/article/details/70770148