1. 程式人生 > >機器學習知識點(四)最小二乘法Java實現

機器學習知識點(四)最小二乘法Java實現

最小二乘法(又稱最小平方法)是一種數學優化技術。它通過最小化誤差的平方和尋找資料的最佳函式匹配。利用最小二乘法可以簡便地求得未知的資料,並使得這些求得的資料與實際資料之間誤差的平方和為最小。最小二乘法還可用於曲線擬合。通過一元線性模型應用來理解最小二乘法。

監督學習任務中,預測離散結果的是分類任務,預測連續結果的是迴歸任務。在迴歸任務中,預測結果y和x的函式關係中,一元線性迴歸只包含一個屬性的,對應的線性關係;二元線性迴歸包含兩個屬性,對應的平面關係;d元線性迴歸就包括d個屬性,對應的超平面關係。

在一元線性迴歸任務中,給定資料集{(x1,y1),(x2,y2),…,(xn,yn)},有n個(xi

,yi)資料對,在座標中對應n個點,要擬合這n個點為一條直線的線性關係,自然是直線在n個點中間最好。但顯然有很多直線滿足,怎麼衡量呢?選擇怎樣的直線最好呢?標準是什麼?選擇最佳直線的標準是:使總的擬合誤差(即總殘差)達到最小。

1)用“殘差和最小”確定直線位置,存在相互抵消的問題。

2)用“殘差絕對值和最小”確定直線位置,但絕對值的計算比較麻煩。

3)最小二乘法的原則是以“殘差平方和最小”確定直線位置,用最小二乘法除了計算比較方便外,得到的估計量還具有優良特性、對異常值非常敏感。

綜上,我們選用最小二乘法的誤差平方和最小作為標準來選出一條直線作為n個點的擬合直線。最常用的是普通最小二乘法( Ordinary  Least Square,OLS):所選擇的迴歸模型應該使所有觀察值的殘差平方和達到最小。(Q為殘差平方和)- 即採用平方損失函式。

數學形式定義直線為:

f(xi)=yi=axi+b+ei;

其中,i∈[1,n],ei是樣本(xi,yi)的真實值yi=axi+b+ei和擬合值y’i= axi+b的誤差,即ei= yi-axi-b。


最小二乘法一元線性迴歸模型的Java實現,參考程式碼如下:

package sk.ml;

import java.text.DecimalFormat;
import java.util.Random;

/*
 * 功能:一元線性迴歸模型最小二乘法Java實現
 * 作者:Jason.F
 * 時間:2017-01-16
 */
public class LeastSquares {
	private final static int n=20;//隨機生成10個點的(x,y)
	public static void main(String[] args){
		//隨機生成20個座標點
		Random random = new Random();	
		double[] x=new double[n];
		double[] y=new double[n];
		for(int i=0;i<n;i++){//隨機生成n個double數
			x[i]=Double.valueOf(Math.floor(random.nextDouble()*(99-1)));
			y[i]=Double.valueOf(Math.floor(random.nextDouble()*(999-1)));
		}
        /* y = a x + b
		 * b = sum( y ) / n - a * sum( x ) / n
		 * a = ( n * sum( xy ) - sum( x ) * sum( y ) ) / ( n * sum( x^2 ) - sum(x) ^ 2 )
		 * */
		estimate(x, y, x.length );
	}
	/**
	  * 預測
	  * @param x,y,i
	  */
	public static void estimate( double[] x , double[] y , int i ) {
		double a = getA( x , y ) ;
		double b = getB( x , y , a ) ;
		//設定doubl字串輸出格式,不以科學計數法輸出	
		DecimalFormat df=new DecimalFormat("#,##0.00");//格式化設定
		System.out.println("y="+df.format(a)+"x"+"+"+df.format(b));
	}
	 
	 /**
	  * 計算 x的係數a
	  * @param x, y
	  * @return a
	  */
	 public static double getA( double[] x , double[] y ){
		 int n = x.length ;
		 return ( n * pSum( x , y ) - sum( x ) * sum( y ) )/ ( n * sqSum( x ) - Math.pow(sum(x), 2) ) ;
	 }
	 
	 /**
	  * 計算常量係數b
	  * @param x,y,a
	  * @returnb
	  */
	 public static double getB( double[] x , double[] y , double a ){
		 int n = x.length ;
		 return sum( y ) / n - a * sum( x ) / n ;
	 }
	 
	 /**
	  * 計算常量係數b
	  * @param x, y
	  * @return b
	  */
	 public static double getC( double[] x , double[] y ){
		 int n = x.length ;
		 double a = getA( x , y ) ;
		 return sum( y ) / n - a * sum( x ) / n ;
	 }
	 
	 //計算和值
	 private static double sum(double[] ds) {
		 double s = 0 ;
		 for( double d : ds ) s = s + d ;
		 return s ;
	 }
	 //計算開平方和值
	 private static double sqSum(double[] ds) {
		 double s = 0 ;
		 for( double d : ds ) s = s + Math.pow(d, 2) ;
		 return s ;
	 }
	 //計算x和y積的和值
	 private static double pSum( double[] x , double[] y ) {
		 double s = 0 ;
		 for( int i = 0 ; i < x.length ; i++ ) s = s + x[i] * y[i] ;
		 return s ;
	 }
}

隨機生成的一次執行結果如下:
y=-0.29x+541.23