
3.1 Random Forests: Worked Examples

Example 1: Classifying the iris data with a random forest

#1. Load and inspect the data
data("iris")
summary(iris)
##   Sepal.Length    Sepal.Width     Petal.Length    Petal.Width   
##  Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100  
##  1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300  
##  Median :5.800   Median :3.000   Median :4.350   Median :1.300  
##  Mean   :5.843   Mean   :3.057   Mean   :3.758   Mean   :1.199  
##  3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800  
##  Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500  
##        Species  
##  setosa    :50  
##  versicolor:50  
##  virginica :50  
##                 
##                 
## 
str(iris)
## 'data.frame':    150 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
#2. Create the training and test sets
set.seed(2001)
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 3.2.3
index <- createDataPartition(iris$Species, p=0.7, list=F)
train_iris <- iris[index, ]
test_iris <- iris[-index, ]
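#Optional check (a small sketch): createDataPartition() samples within each
#class, so the 70/30 split should preserve the 50/50/50 balance of Species
table(train_iris$Species)
table(test_iris$Species)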
#3. Build the model
library(randomForest)
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
model_iris <- randomForest(Species~., data=train_iris, ntree=50, nPerm=10, mtry=3, proximity=T, importance=T)
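#Note: mtry=3 is above the classification default of floor(sqrt(4)) = 2
#predictors per split, and nPerm only applies to regression forests, so it
#has no effect in this classification model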

#4. Evaluate the model
model_iris
## 
## Call:
##  randomForest(formula = Species ~ ., data = train_iris, ntree = 50,      nPerm = 10, mtry = 3, proximity = T, importance = T) 
##                Type of random forest: classification
##                      Number of trees: 50
## No. of variables tried at each split: 3
## 
##         OOB estimate of  error rate: 4.76%
## Confusion matrix:
##            setosa versicolor virginica class.error
## setosa         35          0         0  0.00000000
## versicolor      0         32         3  0.08571429
## virginica       0          2        33  0.05714286
str(model_iris)
## List of 19
##  $ call           : language randomForest(formula = Species ~ ., data = train_iris, ntree = 50,      nPerm = 10, mtry = 3, proximity = T, importance = T)
##  $ type           : chr "classification"
##  $ predicted      : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
##   ..- attr(*, "names")= chr [1:105] "5" "7" "8" "11" ...
##  $ err.rate       : num [1:50, 1:4] 0.0513 0.0758 0.0741 0.0435 0.0505 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : NULL
##   .. ..$ : chr [1:4] "OOB" "setosa" "versicolor" "virginica"
##  $ confusion      : num [1:3, 1:4] 35 0 0 0 32 2 0 3 33 0 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : chr [1:3] "setosa" "versicolor" "virginica"
##   .. ..$ : chr [1:4] "setosa" "versicolor" "virginica" "class.error"
##  $ votes          : matrix [1:105, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : chr [1:105] "5" "7" "8" "11" ...
##   .. ..$ : chr [1:3] "setosa" "versicolor" "virginica"
##   ..- attr(*, "class")= chr [1:2] "matrix" "votes"
##  $ oob.times      : num [1:105] 15 23 22 16 17 11 20 20 17 19 ...
##  $ classes        : chr [1:3] "setosa" "versicolor" "virginica"
##  $ importance     : num [1:4, 1:5] 0 0 0.3417 0.34918 -0.00518 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
##   .. ..$ : chr [1:5] "setosa" "versicolor" "virginica" "MeanDecreaseAccuracy" ...
##  $ importanceSD   : num [1:4, 1:4] 0 0 0.04564 0.04711 0.00395 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
##   .. ..$ : chr [1:4] "setosa" "versicolor" "virginica" "MeanDecreaseAccuracy"
##  $ localImportance: NULL
##  $ proximity      : num [1:105, 1:105] 1 1 1 1 1 1 1 1 1 1 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : chr [1:105] "5" "7" "8" "11" ...
##   .. ..$ : chr [1:105] "5" "7" "8" "11" ...
##  $ ntree          : num 50
##  $ mtry           : num 3
##  $ forest         :List of 14
##   ..$ ndbigtree : int [1:50] 11 5 9 9 9 9 9 11 11 9 ...
##   ..$ nodestatus: int [1:17, 1:50] 1 -1 1 1 1 -1 -1 1 -1 -1 ...
##   ..$ bestvar   : int [1:17, 1:50] 4 0 4 3 3 0 0 1 0 0 ...
##   ..$ treemap   : int [1:17, 1:2, 1:50] 2 0 4 6 8 0 0 10 0 0 ...
##   ..$ nodepred  : int [1:17, 1:50] 0 1 0 0 0 2 3 0 3 2 ...
##   ..$ xbestsplit: num [1:17, 1:50] 0.8 0 1.65 5.25 4.85 0 0 6.05 0 0 ...
##   ..$ pid       : num [1:3] 1 1 1
##   ..$ cutoff    : num [1:3] 0.333 0.333 0.333
##   ..$ ncat      : Named int [1:4] 1 1 1 1
##   .. ..- attr(*, "names")= chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
##   ..$ maxcat    : int 1
##   ..$ nrnodes   : int 17
##   ..$ ntree     : num 50
##   ..$ nclass    : int 3
##   ..$ xlevels   :List of 4
##   .. ..$ Sepal.Length: num 0
##   .. ..$ Sepal.Width : num 0
##   .. ..$ Petal.Length: num 0
##   .. ..$ Petal.Width : num 0
##  $ y              : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
##   ..- attr(*, "names")= chr [1:105] "5" "7" "8" "11" ...
##  $ test           : NULL
##  $ inbag          : NULL
##  $ terms          :Classes 'terms', 'formula' length 3 Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
##   .. ..- attr(*, "variables")= language list(Species, Sepal.Length, Sepal.Width, Petal.Length, Petal.Width)
##   .. ..- attr(*, "factors")= int [1:5, 1:4] 0 1 0 0 0 0 0 1 0 0 ...
##   .. .. ..- attr(*, "dimnames")=List of 2
##   .. .. .. ..$ : chr [1:5] "Species" "Sepal.Length" "Sepal.Width" "Petal.Length" ...
##   .. .. .. ..$ : chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
##   .. ..- attr(*, "term.labels")= chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
##   .. ..- attr(*, "order")= int [1:4] 1 1 1 1
##   .. ..- attr(*, "intercept")= num 0
##   .. ..- attr(*, "response")= int 1
##   .. ..- attr(*, ".Environment")=<environment: R_GlobalEnv> 
##   .. ..- attr(*, "predvars")= language list(Species, Sepal.Length, Sepal.Width, Petal.Length, Petal.Width)
##   .. ..- attr(*, "dataClasses")= Named chr [1:5] "factor" "numeric" "numeric" "numeric" ...
##   .. .. ..- attr(*, "names")= chr [1:5] "Species" "Sepal.Length" "Sepal.Width" "Petal.Length" ...
##  - attr(*, "class")= chr [1:2] "randomForest.formula" "randomForest"
pred <- predict(model_iris, train_iris)
mean(pred==train_iris[, 5])
## [1] 1
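#Note: perfect accuracy on the training data is expected, since each tree
#fits its own bootstrap sample closely; the OOB estimate above (4.76%) is
#the more honest in-model error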
#5. Predict
pred_iris <- predict(model_iris, test_iris)
table(pred_iris, test_iris[, 5])
##             
## pred_iris    setosa versicolor virginica
##   setosa         15          0         0
##   versicolor      0         13         2
##   virginica       0          2        13
mean(pred_iris==test_iris[, 5])
## [1] 0.9111111
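#Optional sketch: caret (already loaded) can summarize the same comparison
#with accuracy, kappa and per-class statistics
confusionMatrix(pred_iris, test_iris$Species)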
library(gmodels)
CrossTable(pred_iris, test_iris[, 5])
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## | Chi-square contribution |
## |           N / Row Total |
## |           N / Col Total |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  45 
## 
##  
##              | test_iris[, 5] 
##    pred_iris |     setosa | versicolor |  virginica |  Row Total | 
## -------------|------------|------------|------------|------------|
##       setosa |         15 |          0 |          0 |         15 | 
##              |     20.000 |      5.000 |      5.000 |            | 
##              |      1.000 |      0.000 |      0.000 |      0.333 | 
##              |      1.000 |      0.000 |      0.000 |            | 
##              |      0.333 |      0.000 |      0.000 |            | 
## -------------|------------|------------|------------|------------|
##   versicolor |          0 |         13 |          2 |         15 | 
##              |      5.000 |     12.800 |      1.800 |            | 
##              |      0.000 |      0.867 |      0.133 |      0.333 | 
##              |      0.000 |      0.867 |      0.133 |            | 
##              |      0.000 |      0.289 |      0.044 |            | 
## -------------|------------|------------|------------|------------|
##    virginica |          0 |          2 |         13 |         15 | 
##              |      5.000 |      1.800 |     12.800 |            | 
##              |      0.000 |      0.133 |      0.867 |      0.333 | 
##              |      0.000 |      0.133 |      0.867 |            | 
##              |      0.000 |      0.044 |      0.289 |            | 
## -------------|------------|------------|------------|------------|
## Column Total |         15 |         15 |         15 |         45 | 
##              |      0.333 |      0.333 |      0.333 |            | 
## -------------|------------|------------|------------|------------|
## 
## 

Example 2: Applying a random forest to the Titanic passenger survival data

Two parameters of the randomForest() function strongly influence model accuracy: mtry and ntree. mtry is usually chosen by trying candidate values one at a time until a good value is found; ntree can be chosen from a plot, roughly at the point where the in-model error stabilizes. The randomForest package's randomForest(formula, data, ntree, nPerm, mtry, proximity, importance) function performs random forest classification and regression. ntree is the number of trees to grow (it should not be set too small; the default is 500); nPerm is the number of times the OOB data are permuted per tree when assessing variable importance (values larger than 1 give slightly more stable estimates, but are not very effective; it is currently only implemented for regression); mtry is the number of predictors sampled as split candidates at each node; proximity=T computes the proximity matrix; and importance=T assesses predictor importance.
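As an aside, the randomForest package also provides tuneRF(), which automates the mtry search by stepping mtry up and down from a starting value and keeping the setting with the lowest OOB error. A minimal sketch on the iris training set from Example 1 (assuming train_iris is still in the workspace):

set.seed(2001)
#double or halve mtry from the default until the OOB error stops improving by at least 1%
tuneRF(x=train_iris[, -5], y=train_iris$Species, ntreeTry=50, stepFactor=2, improve=0.01)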

Below we apply a random forest to the Titanic passenger survival data and see how accurate the model is.

#1. Load and inspect the data: read both the training and the test set
train <- read.table("F:/R/Rworkspace/RandomForest/train.csv", header=T, sep=",")
test <- read.table("F:/R/Rworkspace/RandomForest/test.csv", header=T, sep=",")
#Note: the training and test sets come from different files, so make sure the factor levels in the test set match those in the training set; otherwise predict() with a model trained on the training set will fail on the test set!

str(train)
## 'data.frame':    891 obs. of  8 variables:
##  $ Survived: int  0 1 1 1 0 0 0 0 1 1 ...
##  $ Pclass  : int  3 1 3 1 3 3 1 3 3 2 ...
##  $ Sex     : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1 ...
##  $ Age     : num  22 38 26 35 35 NA 54 2 27 14 ...
##  $ SibSp   : int  1 1 0 1 0 0 0 3 0 1 ...
##  $ Parch   : int  0 0 0 0 0 0 0 1 2 0 ...
##  $ Fare    : num  7.25 71.28 7.92 53.1 8.05 ...
##  $ Embarked: Factor w/ 4 levels "","C","Q","S": 4 2 4 4 4 3 4 4 4 2 ...
str(test)
## 'data.frame':    418 obs. of  7 variables:
##  $ Pclass  : int  3 3 2 3 3 3 3 2 3 3 ...
##  $ Sex     : Factor w/ 2 levels "female","male": 2 1 2 2 1 2 1 2 1 2 ...
##  $ Age     : num  34.5 47 62 27 22 14 30 26 18 21 ...
##  $ SibSp   : int  0 1 0 0 1 0 0 1 0 2 ...
##  $ Parch   : int  0 0 0 0 1 0 0 1 0 0 ...
##  $ Fare    : num  7.83 7 9.69 8.66 12.29 ...
##  $ Embarked: Factor w/ 3 levels "C","Q","S": 2 3 2 3 3 3 2 3 1 3 ...
#From the above: the training set has 891 records and 8 variables, with 4 levels for Embarked; the test set has 418 records and 7 variables, with 3 levels for Embarked; the training set contains missing values; the response Survived is numeric; and the test set has no response variable

#2. Data cleaning
#1) Align the test set's factor levels with the training set's
levels(train$Embarked)
## [1] ""  "C" "Q" "S"
levels(test$Embarked)
## [1] "C" "Q" "S"
#factor() recodes the values against the training levels; assigning to
#levels(test$Embarked) directly would silently relabel the values instead
test$Embarked <- factor(test$Embarked, levels=levels(train$Embarked))

#2) Convert the response variable to a factor
train$Survived <- as.factor(train$Survived)
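#(Otherwise randomForest() would fit a regression forest, because Survived
#is stored as numeric 0/1)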

#3) Impute the missing values (NA) in the training set with rfImpute()
library(randomForest)
train_impute <- rfImpute(Survived~., data=train)
## ntree      OOB      1      2
##   300:  16.39%  7.83% 30.12%
## ntree      OOB      1      2
##   300:  16.50%  8.93% 28.65%
## ntree      OOB      1      2
##   300:  16.72%  8.74% 29.53%
## ntree      OOB      1      2
##   300:  16.50%  8.56% 29.24%
## ntree      OOB      1      2
##   300:  17.28%  9.47% 29.82%
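#rfImpute() starts from a rough median/mode fill and then refines the
#imputed values over several iterations (5 by default) using the forest's
#proximity matrix; the five OOB blocks above are one pass each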
#4) Impute the test set's missing values: the test set also contains NAs, and rfImpute() needs a response variable, which the test set lacks, so we fill them with multiple imputation instead
summary(test)
##      Pclass          Sex           Age            SibSp       
##  Min.   :1.000   female:152   Min.   : 0.17   Min.   :0.0000  
##  1st Qu.:1.000   male  :266   1st Qu.:21.00   1st Qu.:0.0000  
##  Median :3.000                Median :27.00   Median :0.0000  
##  Mean   :2.266                Mean   :30.27   Mean   :0.4474  
##  3rd Qu.:3.000                3rd Qu.:39.00   3rd Qu.:1.0000  
##  Max.   :3.000                Max.   :76.00   Max.   :8.0000  
##                               NA's   :86                      
##      Parch             Fare         Embarked
##  Min.   :0.0000   Min.   :  0.000    :  0   
##  1st Qu.:0.0000   1st Qu.:  7.896   C:102   
##  Median :0.0000   Median : 14.454   Q: 46   
##  Mean   :0.3923   Mean   : 35.627   S:270   
##  3rd Qu.:0.0000   3rd Qu.: 31.500           
##  Max.   :9.0000   Max.   :512.329           
##                   NA's   :1
#As shown above, the test set contains NAs in Age and Fare

#Fill the missing values by multiple imputation:
library(mice)
## Loading required package: Rcpp
## mice 2.25 2015-11-09
imput <- mice(data=test, m=10)
## 
##  iter imp variable
##   1   1  Age  Fare
##   1   2  Age  Fare
##   1   3  Age  Fare
##   1   4  Age  Fare
##   1   5  Age  Fare
##   1   6  Age  Fare
##   1   7  Age  Fare
##   1   8  Age  Fare
##   1   9  Age  Fare
##   1   10  Age  Fare
##   2   1  Age  Fare
##   2   2  Age  Fare
##   2   3  Age  Fare
##   2   4  Age  Fare
##   2   5  Age  Fare
##   2   6  Age  Fare
##   2   7  Age  Fare
##   2   8  Age  Fare
##   2   9  Age  Fare
##   2   10  Age  Fare
##   3   1  Age  Fare
##   3   2  Age  Fare
##   3   3  Age  Fare
##   3   4  Age  Fare
##   3   5  Age  Fare
##   3   6  Age  Fare
##   3   7  Age  Fare
##   3   8  Age  Fare
##   3   9  Age  Fare
##   3   10  Age  Fare
##   4   1  Age  Fare
##   4   2  Age  Fare
##   4   3  Age  Fare
##   4   4  Age  Fare
##   4   5  Age  Fare
##   4   6  Age  Fare
##   4   7  Age  Fare
##   4   8  Age  Fare
##   4   9  Age  Fare
##   4   10  Age  Fare
##   5   1  Age  Fare
##   5   2  Age  Fare
##   5   3  Age  Fare
##   5   4  Age  Fare
##   5   5  Age  Fare
##   5   6  Age  Fare
##   5   7  Age  Fare
##   5   8  Age  Fare
##   5   9  Age  Fare
##   5   10  Age  Fare
Age <- data.frame(Age=apply(imput$imp$Age, 1, mean))    #average the 10 imputed values per missing Age
Fare <- data.frame(Fare=apply(imput$imp$Fare, 1, mean)) #same for Fare
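#Alternative sketch: instead of averaging the 10 imputations cell by cell
#as above, mice::complete() returns a single completed data set directly
test_complete <- complete(imput, action=1)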

#Add row IDs:
test$Id <- row.names(test)
Age$Id <- row.names(Age)
Fare$Id <- row.names(Fare)

#Replace the missing values:
test[test$Id %in% Age$Id, 'Age'] <- Age$Age
test[test$Id %in% Fare$Id, 'Fare'] <- Fare$Fare
summary(test)
##      Pclass          Sex           Age            SibSp       
##  Min.   :1.000   female:152   Min.   : 0.17   Min.   :0.0000  
##  1st Qu.:1.000   male  :266   1st Qu.:22.00   1st Qu.:0.0000  
##  Median :3.000                Median :26.19   Median :0.0000  
##  Mean   :2.266                Mean   :29.41   Mean   :0.4474  
##  3rd Qu.:3.000                3rd Qu.:36.65   3rd Qu.:1.0000  
##  Max.   :3.000                Max.   :76.00   Max.   :8.0000  
##      Parch             Fare         Embarked      Id           
##  Min.   :0.0000   Min.   :  0.000    :  0    Length:418        
##  1st Qu.:0.0000   1st Qu.:  7.896   C:102    Class :character  
##  Median :0.0000   Median : 14.454   Q: 46    Mode  :character  
##  Mean   :0.3923   Mean   : 35.583   S:270                      
##  3rd Qu.:0.0000   3rd Qu.: 31.472                              
##  Max.   :9.0000   Max.   :512.329
#As shown above, the test set no longer contains any NAs.

#3. Choose the random forest's mtry and ntree values
#1) Choose mtry
(n <- ncol(train_impute) - 1)  #number of predictor variables
## [1] 7
library(randomForest)
for(i in 1:n) {
  model <- randomForest(Survived~., data=train_impute, mtry=i)
  err <- mean(model$err.rate)  #averages the OOB and per-class error columns
  print(err)
}
## [1] 0.2100028
## [1] 0.1889116
## [1] 0.1776607
## [1] 0.1902606
## [1] 0.1960938
## [1] 0.1953451
## [1] 0.1951303
#From the above, the in-model error is smallest at mtry=3, with mtry=2 a close second, so we take mtry=2 or mtry=3

#2) Choose ntree
set.seed(2002)
model <- randomForest(Survived~., data=train_impute, mtry=2, ntree=1000)
plot(model)

#The plot shows the in-model error stabilizing at around ntree=400, so we take ntree=400
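#Sketch: the OOB error can also be read off the model directly instead of
#eyeballing the plot (err.rate has one row per tree; column "OOB" is the
#overall error)
oob_err <- model$err.rate[, "OOB"]
which.min(oob_err)  #tree count with the lowest OOB error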

#4. Build the model
model_fit <- randomForest(Survived~., data=train_impute, mtry=2, ntree=400, importance=T)

#5. Evaluate the model
model_fit
## 
## Call:
##  randomForest(formula = Survived ~ ., data = train_impute, mtry = 2,      ntree = 400, importance = T) 
##                Type of random forest: classification
##                      Number of trees: 400
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 16.61%
## Confusion matrix:
##     0   1 class.error
## 0 500  49  0.08925319
## 1  99 243  0.28947368
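#Note that the OOB error is unbalanced: about 8.9% for class 0 (died) but
#28.9% for class 1 (survived); randomForest's classwt or sampsize arguments
#are one option if the minority class matters more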
#Inspect variable importance
(importance <- importance(x=model_fit))
##                  0         1 MeanDecreaseAccuracy MeanDecreaseGini
## Pclass   16.766454 28.241508             32.16125         33.15984
## Sex      46.578191 76.145306             72.42624        100.74843
## Age      19.882605 24.586274             30.52032         60.85186
## SibSp    19.070707  2.834303             18.95690         16.11720
## Parch    10.366140  8.380559             13.18282         12.28725
## Fare     18.649672 20.967558             29.43262         66.31489
## Embarked  7.904436 11.479919             14.18780         12.68924
#Plot variable importance
varImpPlot(model_fit)

#The plot shows that Sex is by far the most important variable, followed by Pclass, Age and Fare on MeanDecreaseAccuracy and by Fare, Age and Pclass on MeanDecreaseGini.

#6. Predict
#1) Predict on the training set:
train_pred <- predict(model_fit, train_impute)
mean(train_pred==train_impute$Survived)
## [1] 0.9135802
table(train_pred, train_impute$Survived)
##           
## train_pred   0   1
##          0 535  63
##          1  14 279
#The model's accuracy on the training data is above 90%

#2) Predict on the test set:
test_pred <- predict(model_fit, test[, 1:7])
head(test_pred)
## 1 2 3 4 5 6 
## 0 0 0 0 1 0 
## Levels: 0 1
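#Finally, a sketch for saving the test-set predictions, e.g. as a
#Kaggle-style submission (the file name and the use of the row Id column
#are our own choices; this test file carries no passenger-ID column)
submission <- data.frame(Id=test$Id, Survived=test_pred)
write.csv(submission, "rf_titanic_pred.csv", row.names=FALSE)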