1. 程式人生 > >Multinomial Logit Model (MNL) 模型R語言nnet包multinom函式實現例項

Multinomial Logit Model (MNL) 模型R語言nnet包multinom函式實現例項

最近做專案涉及到要使用multinomial logit model (MNL) 模型。看了一堆文獻講mnl, 但是沒有給什麼具體能上手的例項,就算有也是一筆帶過,打算找一些使用R 語言來實現mnl模型的例子,在模仿和實踐中慢慢理解。

Multinomial Logit Model又有很多其它說法,諸如Multinomial Logistic Regression等等。

本文的例項來自兩篇文章。

第一篇  R Data Analysis Examples: Multinomial Logistic Regression

第一篇是UCLA的idre機構網站中,關於R語言實現 Multinomial Logistic Regression 的教程

Multinomial logistic regression被用於輸出結果為 nominal variables 的建模。

本文使用了一下的包,請確保你能載入這些包,如果你沒有安裝,可以使用語句 :install.packages("packagename"), 或者如果你使用的包的版本太低,可以使用語句: update.packages()  .

require(foreign)
require(nnet)
require(ggplot2)
require(reshape2)
Version info: Code for this page was tested in R version 3.1.1 (2014-07-10)

On: 2015-12-17
With: reshape2 1.4.1; ggplot2 1.0.1; nnet 7.3-10; foreign 0.8-65; knitr 1.10.5


Multinomial Logistic Regression的例子

例1: 人們的職業選擇結果可能會被父母的職業和他們自己的教育水平所影響。我們可以研究某個人的職業選擇和他的教育水平、父母的職業之間的關係。而人們所選擇的職業多種多樣,不是隻有一種或者兩種。

例2:一個生物學家啃呢個會對短吻鱷所選擇的食物感興趣。成年的短吻鱷可能與幼年的短吻鱷的食物偏好不同。所以我們這裡的有各種各樣的食物作為選擇結果,該結果是被短吻鱷的形體大小和環境變數所影響。

例3:進入高等教育的學生對於選擇什麼樣的學習專案型別有三種選擇:普通專案,針對工作的專案以及學術型的專案。他們的選擇可能被他們的寫作成績和社會經濟地位影響。

資料描述

我們的資料分析例子使用的是第三個例子,使用 hsbdemo 資料。首先是先讀入資料。

ml <- read.dta("http://www.ats.ucla.edu/stat/data/hsbdemo.dta")
 該資料包括200個學生的選擇的專案型別(prog, 三種類型 categorical variable), 他們的社會地位(ses 三種地位 categorical variable),寫作分數(write, a continuous variable)。讀完資料後,我們可以使用一些語句來對我們的資料建立一個初步的概念和感覺。
with(ml, table(ses, prog))
##         prog
## ses      general academic vocation
##   low         16       19       12
##   middle      20       44       31
##   high         9       42        7

從結果中可以看出,社會經濟地位中等的學生最多,低等的最少。社會經濟地位高的學生中,絕大多數都是選擇學術型的program 。

with(ml, do.call(rbind, tapply(write, prog, function(x) c(M = mean(x), SD = sd(x)))))
##             M   SD
## general  51.3 9.40
## academic 56.3 7.94
## vocation 46.8 9.32
從結果中可以看出,選擇學術型專案的學生的寫作成績平均分最高,且波動最小。 選擇職業型專案的學生寫作成績平均分最低。

Multinomial Logistic Regression 方法

以下的方法中我們使用的nnet 包中的multinom 函式。其實在R語言的其它包中也有其它函式可以實現MNL方法(諸如 mlogit)。 我們選擇multinom 函式的原因是,它不需要將資料reshape(而mlogit就需要)。

在執行我們的模型之前,選擇一個合適的參照組(a reference group)非常重要。 我們可以使用relevel 來調換outcome variable 的等級順序。值得注意的是, multinom 包不包含對於迴歸係數的p-value的計算。

ml$prog2 <- relevel(ml$prog, ref = "academic")
test <- multinom(prog2 ~ ses + write, data = ml)
## # weights:  15 (8 variable)
## initial  value 219.722458 
## iter  10 value 179.982880
## final  value 179.981726 
## converged
summary(test)
## Call:
## multinom(formula = prog2 ~ ses + write, data = ml)
## 
## Coefficients:
##          (Intercept) sesmiddle seshigh   write
## general         2.85    -0.533  -1.163 -0.0579
## vocation        5.22     0.291  -0.983 -0.1136
## 
## Std. Errors:
##          (Intercept) sesmiddle seshigh  write
## general         1.17     0.444   0.514 0.0214
## vocation        1.16     0.476   0.596 0.0222
## 
## Residual Deviance: 360 
## AIC: 376
z <- summary(test)$coefficients/summary(test)$standard.errors
z
##          (Intercept) sesmiddle seshigh write
## general         2.45    -1.202   -2.26 -2.71
## vocation        4.48     0.612   -1.65 -5.11
#2-tailed z test
p <- (1 - pnorm(abs(z), 0, 1))*2
p
##          (Intercept) sesmiddle seshigh    write
## general     1.45e-02     0.229  0.0237 6.82e-03
## vocation    7.30e-06     0.541  0.0989 3.18e-07

我們首先可以看到,即使我們已經將multinom執行賦值給了test,模型執行後仍然得到了一些結果。執行結果包括 some iteration history 和最終的負log-likelihood 179.981726. 這一結果乘以2後即為模型summary結果中的 Residual Deviance:360, 它可以用來和 nested models 來比較,但是本例中我們不做比較。

模型summary 後的結果一塊是係數,一塊是  standard errors。每一個塊中,都有一行與模型等式對應。在係數塊中,第一行是general program 與academic program 的比較,第二行代表 vocation program 與 academic program 的比較(academic program 是我們的參考組reference group。 如果我們把第一行的係數記為b1_, 第二行的係數記為b2_, 我們可以寫出我們的模型等式。

ln(P(prog=general)P(prog=academic))=b10+b11(ses=2)+b12(ses=3)+b13write ln(P(prog=vocation)P(prog=academic))=b20+b21(ses=2)+b22(ses=3)+b23write

一種選擇結果類別的概率與基準選擇的概率的比被稱為相對危險度(relative risk), 有時候只是為了描述迴歸引數,我們也叫做odds. 相對危險度是對線性方程等號右邊部分的取冪,取冪後的迴歸係數就是自變數每單位變化的相對危險度。下面我們開對模型的係數取冪來觀察是如何變化的。
## extract the coefficients from the model and exponentiate
exp(coef(test))
##          (Intercept) sesmiddle seshigh write
## general         17.3     0.587   0.313 0.944
## vocation       184.6     1.338   0.374 0.893

  • “write”變數增加一個單位,普通專案vs學術專案的相對危險風險比(the relative risk ratio)是0.9437
  • “ses”社會地位變數從1變為3時,普通專案VS學術專案的相對風險比是0.3126
你也可以使用預測的概率來幫助你理解這個模型。你可以使用 fitted 函式來得到模型的擬合值(估計值)。(在這裡和原資料比對了一下,感覺不是很精確呀。)
head(pp <- fitted(test))
##   academic general vocation
## 1    0.148   0.338    0.513
## 2    0.120   0.181    0.699
## 3    0.419   0.237    0.345
## 4    0.173   0.351    0.476
## 5    0.100   0.169    0.731
## 6    0.353   0.238    0.409

然後,如果你想檢測我們變數中某個變數的改變對預測結果的影響,可以建立一個小的資料組,改變其中一個變數,其它變數保持不變。首先,我們讓"write" 變數保持不變,檢測社會地位的改變對預測值的影響。
dses <- data.frame(ses = c("low", "middle", "high"),
  write = mean(ml$write))
predict(test, newdata = dses, "probs")
##   academic general vocation
## 1    0.440   0.358    0.202
## 2    0.478   0.228    0.294
## 3    0.701   0.178    0.121
另一種理解模型的方式是在三種不同的社會地位下,連續改變"write"值,並對該社會地位中的預測值取平均。
dwrite <- data.frame(ses = rep(c("low", "middle", "high"), each = 41),
  write = rep(c(30:70), 3))

## store the predicted probabilities for each value of ses and write
pp.write <- cbind(dwrite, predict(test, newdata = dwrite, type = "probs", se = TRUE))

## calculate the mean probabilities within each level of ses
by(pp.write[, 3:5], pp.write$ses, colMeans)
## pp.write$ses: high
## academic  general vocation 
##    0.616    0.181    0.203 
## ------------------------------------------------------ 
## pp.write$ses: low
## academic  general vocation 
##    0.397    0.328    0.275 
## ------------------------------------------------------ 
## pp.write$ses: middle
## academic  general vocation 
##    0.426    0.201    0.373

有時候,一組圖能過很好地表達大量的資訊。使用我們上面的“pp.write” 物件,我們可以繪製在不同社會地位下,預測值與寫作分數的關係的圖。
## melt data set to long for ggplot2
lpp <- melt(pp.write, id.vars = c("ses", "write"), value.name = "probability")
head(lpp) # view first few rows
##   ses write variable probability
## 1 low    30 academic      0.0984
## 2 low    31 academic      0.1072
## 3 low    32 academic      0.1165
## 4 low    33 academic      0.1265
## 5 low    34 academic      0.1370
## 6 low    35 academic      0.1483
## plot predicted probabilities across write values for
## each level of ses facetted by program type
ggplot(lpp, aes(x = write, y = probability, colour = ses)) +
  geom_line() +
  facet_grid(variable ~ ., scales="free")
Predicted probabilities plot