1. 程式人生 > >一元線性迴歸模型與最小二乘法及其C++實現

一元線性迴歸模型與最小二乘法及其C++實現

        監督學習中,如果預測的變數是離散的,我們稱其為分類(如決策樹,支援向量機等),如果預測的變數是連續的,我們稱其為迴歸。迴歸分析中,如果只包括一個自變數和一個因變數,且二者的關係可用一條直線近似表示,這種迴歸分析稱為一元線性迴歸分析。如果迴歸分析中包括兩個或兩個以上的自變數,且因變數和自變數之間是線性關係,則稱為多元線性迴歸分析。對於二維空間線性是一條直線;對於三維空間線性是一個平面,對於多維空間線性是一個超平面…這裡,談一談最簡單的一元線性迴歸模型。

1.一元線性迴歸模型

模型如下:



總體迴歸函式中Y與X的關係可是線性的,也可是非線性的。對線性迴歸模型的“線性”有兩種解釋:

      (1)就變數而言是線性的,Y的條件均值是 X的線性函式

     (2)就引數而言是線性的,Y的條件均值是引數的線性函式

線性迴歸模型主要指就引數而言是“線性”,因為只要對引數而言是線性的,都可以用類似的方法估計其引數。

2.引數估計——最小二乘法

        對於一元線性迴歸模型, 假設從總體中獲取了n組觀察值(X1,Y1),(X2,Y2), …,(Xn,Yn)。對於平面中的這n個點,可以使用無數條曲線來擬合。要求樣本回歸函式儘可能好地擬合這組值。綜合起來看,這條直線處於樣本資料的中心位置最合理。 選擇最佳擬合曲線的標準可以確定為:使總的擬合誤差(即總殘差)達到最小。有以下三個標準可以選擇:

        (1)用“殘差和最小”確定直線位置是一個途徑。但很快發現計算“殘差和”存在相互抵消的問題。
        (2)用“殘差絕對值和最小”確定直線位置也是一個途徑。但絕對值的計算比較麻煩。
        (3)最小二乘法的原則是以“殘差平方和最小”確定直線位置。用最小二乘法除了計算比較方便外,得到的估計量還具有優良特性。這種方法對異常值非常敏感。

        最常用的是普通最小二乘法( Ordinary  Least Square,OLS):所選擇的迴歸模型應該使所有觀察值的殘差平方和達到最小。(Q為殘差平方和)

樣本回歸模型:


殘差平方和:


則通過Q最小確定這條直線,即確定

,以為變數,把它們看作是Q的函式,就變成了一個求極值的問題,可以通過求導數得到。求Q對兩個待估引數的偏導數:


解得:


3.最小二乘法c++實現

[cpp] view plain copy print?
  1. #include<iostream>
  2. #include<fstream>
  3. #include<vector>
  4. usingnamespace std;  
  5. class LeastSquare{  
  6.     double a, b;  
  7. public:  
  8.     LeastSquare(const vector<double>& x, const vector<double>& y)  
  9.     {  
  10.         double t1=0, t2=0, t3=0, t4=0;  
  11.         for(int i=0; i<x.size(); ++i)  
  12.         {  
  13.             t1 += x[i]*x[i];  
  14.             t2 += x[i];  
  15.             t3 += x[i]*y[i];  
  16.             t4 += y[i];  
  17.         }  
  18.         a = (t3*x.size() - t2*t4) / (t1*x.size() - t2*t2);  
  19.         //b = (t4 - a*t2) / x.size();
  20.         b = (t1*t4 - t2*t3) / (t1*x.size() - t2*t2);  
  21.     }  
  22.     double getY(constdouble x) const
  23.     {  
  24.         return a*x + b;  
  25.     }  
  26.     void print() const
  27.     {  
  28.         cout<<”y = ”<<a<<“x + ”<<b<<“\n”;  
  29.     }  
  30. };  
  31. int main(int argc, char *argv[])  
  32. {  
  33.     if(argc != 2)  
  34.     {  
  35.         cout<<”Usage: DataFile.txt”<<endl;  
  36.         return -1;  
  37.     }  
  38.     else
  39.     {  
  40.         vector<double> x;  
  41.         ifstream in(argv[1]);  
  42.         for(double d; in>>d; )  
  43.             x.push_back(d);  
  44.         int sz = x.size();  
  45.         vector<double> y(x.begin()+sz/2, x.end());  
  46.         x.resize(sz/2);  
  47.         LeastSquare ls(x, y);  
  48.         ls.print();  
  49.         cout<<”Input x:\n”;  
  50.         double x0;  
  51.         while(cin>>x0)  
  52.         {  
  53.             cout<<”y = ”<<ls.getY(x0)<<endl;  
  54.             cout<<”Input x:\n”;  
  55.         }  
  56.     }  
  57. }