Java 使用 CommonsMath3 的線性和非線性擬合例項,帶效果圖

例子檢視

  • GitHub
  • Gitee
  • 執行src/main/java/org/wfw/chart/Main.java 即可檢視效果
  • src/main/java/org/wfw/math 包下是簡單的使用

版本說明

  • JDK:1.8
  • commons-math:3.6.1

一些基礎知識

  • 線性:兩個變數之間存在一次方函式關係,就稱它們之間存線上性關係。也就是如下的函式:
\[f(x)=kx+b
\]
  • 非線性:除了線性其他的都是非線性,例如:
\[f(x)=e^x
\]
  • 矩陣:矩陣(Matrix)是一個按照長方陣列排列的複數或實數集合,可以理解為平面或者空間的座標點。

    看大佬怎麼說之>> B站-線性代數的本質 - 系列合集

  • 微分、積分:互為逆過程,一句話概括,微分就是求導,求某個點的極小變化量的斜率。積分是求一些列變化點的和,幾何意義是面積

    看大佬怎麼說之>> B站-微積分的本質 - 系列合集

  • 擬合:形象的說,擬合就是把平面上一系列的點,用一條光滑的曲線連線起來的過程。找到一條最符合這些散點的曲線,使得儘可能多的落在曲線上。常用的方法是最小二乘法。也就是最小二乘問題


新增依賴

Maven 中新增依賴

<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>

如果你是 Gradle

// https://mvnrepository.com/artifact/org.apache.commons/commons-math3
compile group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'

如何使用和驗證

  1. 假設函式已知
  2. 根據函式並新增隨機數R生成一系列散點資料(藍色)
  3. 進行擬合,根據擬合結果生成擬合曲線
  4. 對比結果曲線(綠色)和散點曲線

例如

\[f(x) = 2x + 3
\]

首先根絕函式生成 \(x\) 取任意實數時的以及所對應的 \(f(x)\) 得到資料集 \(xy\)

\[f(x,y) = (0,3)*R, (1,5)*R, (2,7)*R...(n,2n+3)*R
\]

然後對這組資料進行擬合,然後和已知函式 \(f(x)\) 對比斜率 \(k\) 以及截距 \(b\)


1. 線性擬合

線性函式:

\[f(x) = kx + b
\]

假設函式為:

\[f(x) = 1.5x + 0.5
\]

生成資料集合:

/**
*
* y = kx + b
* f(x) = 1.5x + 0.5
*
* @return
*/
public static double[][] linearScatters() {
List<double[]> data = new ArrayList<>();
for (double x = 0; x <= 10; x += 0.1) {
double y = 1.5 * x + 0.5;
y += Math.random() * 4 - 2; // 隨機數
double[] xy = {x, y};
data.add(xy);
}
return data.stream().toArray(double[][]::new);
}

進行擬合

public static Result linearFit(double[][] data) {
List<double[]> fitData = new ArrayList<>();
SimpleRegression regression = new SimpleRegression();
regression.addData(data); // 資料集
/*
* RegressionResults 中是擬合的結果
* 其中重要的幾個引數如下:
* parameters:
* 0: b
* 1: k
* globalFitInfo
* 0: 平方誤差之和, SSE
* 1: 平方和, SST
* 2: R 平方, RSQ
* 3: 均方誤差, MSE
* 4: 調整後的 R 平方, adjRSQ
*
* */
RegressionResults results = regression.regress();
double b = results.getParameterEstimate(0);
double k = results.getParameterEstimate(1);
double r2 = results.getRSquared(); // 重新計算生成擬合曲線
for (double[] datum : data) {
double[] xy = {datum[0], k * datum[0] + b};
fitData.add(xy);
} StringBuilder func = new StringBuilder();
func.append("f(x) =");
func.append(b >= 0 ? " " : " - ");
func.append(Math.abs(b));
func.append(k > 0 ? " + " : " - ");
func.append(Math.abs(k));
func.append("x"); return new Result(fitData.stream().toArray(double[][]::new), func.toString());
}

擬合效果

線性擬合比較簡單,主要是 SimpleRegression 類的 regress() 方法,預設使用 最小二乘法優化器


2. 非線性(曲線)擬合(一元多項式)

非線性函式

\[f(x) = a + bx + cx^2 + dx^3 +...+ mx^n
\]

假設函式為

\[f(x) = 1 + 2x + 3x^2
\]

生成資料集合:

/**
*
* f(x) = 1 + 2x + 3x^2
*
* @return
*/
public static double[][] curveScatters() {
List<double[]> data = new ArrayList<>();
for (double x = 0; x <= 20; x += 1) {
double y = 1 + 2 * x + 3 * x * x;
y += Math.random() * 60 - 10; // 隨機數
double[] xy = {x, y};
data.add(xy);
}
return data.stream().toArray(double[][]::new);
}

進行擬合

public static Result curveFit(double[][] data) {
ParametricUnivariateFunction function = new PolynomialFunction.Parametric();/*多項式函式*/
double[] guess = {1, 2, 3}; /*猜測值 依次為 常數項、1次項、二次項*/ // 初始化擬合
SimpleCurveFitter curveFitter = SimpleCurveFitter.create(function,guess); // 新增資料點
WeightedObservedPoints observedPoints = new WeightedObservedPoints();
for (double[] point : data) {
observedPoints.add(point[0], point[1]);
}
/*
* best 為擬合結果
* 依次為 常數項、1次項、二次項
* 對應 y = a + bx + cx^2 中的 a, b, c
* */
double[] best = curveFitter.fit(observedPoints.toList()); /*
* 根據擬合結果重新計算
* */
List<double[]> fitData = new ArrayList<>();
for (double[] datum : data) {
double x = datum[0];
double y = best[0] + best[1] * x + best[2] * x * x; // y = a + bx + cx^2
double[] xy = {x, y};
fitData.add(xy);
} StringBuilder func = new StringBuilder();
func.append("f(x) =");
func.append(best[0] > 0 ? " " : " - ");
func.append(Math.abs(best[0]));
func.append(best[1] > 0 ? " + " : " - ");
func.append(Math.abs(best[1]));
func.append("x");
func.append(best[2] > 0 ? " + " : " - ");
func.append(Math.abs(best[2]));
func.append("x^2"); return new Result(fitData.stream().toArray(double[][]::new), func.toString());
}

擬合效果

一元多項式曲線的擬合多了一些步驟。但是總歸也是不難的。主要是 SimpleCurveFitter 類以及 ParametricUnivariateFunction 介面。

3. 自定義函式擬合(一元多項式)

總得來說,貌似線性和一元多項式都不難。不過,實際工作或者學術中,一般都是自定義的函式。

假設有一元多項式函式:

\[f(x) = d + \frac{a-d}{1 + (\frac{x}{c})^b}
\]

需要擬合出 a,b,c,d 四個引數的值。

方法:

  1. 實現 ParametricUnivariateFunction 介面
  2. 自定義函式,實現 value 方法
  3. 解偏微分方程,實現 gradient 方法
  4. 設定需要擬合的點
  5. 呼叫SimpleCurveFitter#fit 方法進行擬合

不著急寫程式碼,先看ParametricUnivariateFunction 這個介面的原始碼:

/**
* An interface representing a real function that depends on one independent
* variable plus some extra parameters.
*
* @since 3.0
*/
public interface ParametricUnivariateFunction {
/**
* Compute the value of the function.
* 計算函式的值
* @param x Point for which the function value should be computed.
* @param parameters Function parameters.
* @return the value.
*/
double value(double x, double ... parameters); /**
* Compute the gradient of the function with respect to its parameters.
* 計算函式相對於某個引數的導數
* @param x Point for which the function value should be computed.
* @param parameters Function parameters.
* @return the value.
*/
double[] gradient(double x, double ... parameters);
}
  • value 方法很簡單,就是說計算函式 \(F(x)\) 的值。說人話就是自定義函式的
  • gradient 方法為返回一個數組,其實意思就是求偏微分方程,對每一個要擬合的引數求導就行

不會偏微分方程? 點這裡

按格式輸入你的方程=>輸入自變數=>輸入求導階數(一般都是 1 階)=>計算

好了開始寫程式碼吧,假設函式如下:

\[f(x) = d + \frac{a-d}{1 + (\frac{x}{c})^b}
\]
  1. 自定義 MyFunction 實現 ParametricUnivariateFunction 介面:
static class MyFunction implements ParametricUnivariateFunction {
public double value(double x, double ... parameters) {
double a = parameters[0];
double b = parameters[1];
double c = parameters[2];
double d = parameters[3];
return d + ((a - d) / (1 + Math.pow(x / c, b)));
} public double[] gradient(double x, double ... parameters) {
double a = parameters[0];
double b = parameters[1];
double c = parameters[2];
double d = parameters[3]; double[] gradients = new double[4];
double den = 1 + Math.pow(x / c, b); gradients[0] = 1 / den; // 對 a 求導 gradients[1] = -((a - d) * Math.pow(x / c, b) * Math.log(x / c)) / (den * den); // 對 b 求導 gradients[2] = (b * Math.pow(x / c, b - 1) * (x / (c * c)) * (a - d)) / (den * den); // 對 c 求導 gradients[3] = 1 - (1 / den); // 對 d 求導 return gradients; }
}

生成資料散點

/**
*
* <pre>
* f(x) = d + ((a - d) / (1 + Math.pow(x / c, b)))
* a = 1500
* b = 0.95
* c = 65
* d = 35000
* </pre>
*
* @return
*/
public static double[][] customizeFuncScatters() {
MyFunction function = new MyFunction();
List<double[]> data = new ArrayList<>();
for (double x = 7; x <= 10000; x *= 1.5) {
double y = function.value(x, 1500, 0.95, 65, 35000);
y += Math.random() * 5000 - 2000; // 隨機數
double[] xy = {x, y};
data.add(xy);
}
return data.stream().toArray(double[][]::new);
}

擬合自定義函式

public static Result customizeFuncFit(double[][] scatters) {
ParametricUnivariateFunction function = new MyFunction();/*多項式函式*/
double[] guess = {1500, 0.95, 65, 35000}; /*猜測值 依次為 a b c d 。必須和 gradient 方法返回陣列對應。如果不知道都設定為 1*/ // 初始化擬合
SimpleCurveFitter curveFitter = SimpleCurveFitter.create(function,guess); // 新增資料點
WeightedObservedPoints observedPoints = new WeightedObservedPoints();
for (double[] point : scatters) {
observedPoints.add(point[0], point[1]);
} /*
* best 為擬合結果 對應 a b c d
* 可能會出現無法擬合的情況
* 需要合理設定初始值
* */
double[] best = curveFitter.fit(observedPoints.toList());
double a = best[0];
double b = best[1];
double c = best[2];
double d = best[3]; // 根據擬合結果生成擬合曲線散點
List<double[]> fitData = new ArrayList<>();
for (double[] datum : scatters) {
double x = datum[0];
double y = function.value(x, a, b, c, d);
double[] xy = {x, y};
fitData.add(xy);
} // f(x) = d + ((a - d) / (1 + Math.pow(x / c, b)))
StringBuilder func = new StringBuilder();
func.append("f(x) =");
func.append(d > 0 ? " " : " - ");
func.append(Math.abs(d));
func.append(" ((");
func.append(a > 0 ? "" : "-");
func.append(Math.abs(a));
func.append(d > 0 ? " - " : " + ");
func.append(Math.abs(d));
func.append(" / (1 + ");
func.append("(x / ");
func.append(c > 0 ? "" : " - ");
func.append(Math.abs(c));
func.append(") ^ ");
func.append(b > 0 ? " " : " - ");
func.append(Math.abs(b)); return new Result(fitData.stream().toArray(double[][]::new), func.toString());
}

擬合效果

4. 多元多項式擬合

我用的 javafx8 版本不支援 WebGL 所以無法通過按鈕直接直觀展示擬合效果。我用擬合前得資料和擬合後重新計算的資料進行對比

** 方程 **

\[f(x_1,x_2) = y = a + b * x_1 + c * sin(x_2)
\]

4.1 構造資料

假設: \(a = 20, b = 2, c = 12\) ,則函式 \(f\) 為 \(f(x_1,x_2) = y = 20 + 2 * x_1 + 12 * sin(x_2)\)

根據這個函式構造資料

/**
* 生成隨機數
*/
public static double[][] randomX() {
List<double[]> data = new ArrayList<>();
for (double i = 0; i < 10; i += 0.1) {
double x1 = Math.cos(i);
double x2 = Math.sin(i);
data.add(new double[]{x1, x2});
}
return data.stream().toArray(double[][]::new);
} /**
* f(x1,x2) = y = a + b * x1 + c * sin(x2)
* @param arr
* @return
*/
public static double[] randomY(double[][] arr) {
if (arr != null && arr.length > 0) {
int len = arr.length;
double[] y = new double[len];
for (int i = 0; i < len; i++) {
// f(x1,x2) = y = 20 + x1 + 12 * sin(x2)
double[] x = arr[i];
// 構造資料
y[i] = functionConstructorY(x);
}
return y;
}
return null;
} /**
* 已知的函式為: f(x1,x2) = y = 20 + 2 * x1 + 12 * sin(x2)
* 即:f(x1,x2) = y = a + b * x1 + c * sin(x2) 中
* a = 20, b = 2, c = 12
* @param x
* @return
*/
public static double functionConstructorY(double[] x) {
double x1 = x[0], x2 = x[1];
return 20 + 2 * x1 + Math.sin(10 * x2);
}

4.2 擬合

多元多項式的擬合主要用到 MultipleLinearRegression 介面,它有三個實現方式。我們選擇最小二乘法的實現 OLSMultipleLinearRegression

/**
* 多元多項式資料
* 已知: f(x1,x2) = y = a + b * x1 + c * sin(x2)
*
*/
public static double[][] multiVarPolyScatters() {
double[][] x = randomX();
double[] y = randomY(x);
OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
ols.newSampleData(y, x);
// ct 擬合的常數項(係數)。對應 a,b,c
double[] ct = ols.estimateRegressionParameters();
}

4.3 驗證

根據上面的擬合結果重新計算 \(f(x_1,x_2)\) 的值

/**
* f(x1,x2) = y = a + b * x1 + c * sin(x2)
* @param ct 擬合的常數項(係數)。對應 a,b,c
* @param x x 的值。對應 x1,x2
* @return
*/
public static double functionValueY(double[] ct, double[] x) {
double a = ct[0], b = ct[1], c = ct[2];
double x1 = x[0], x2 = x[1];
return a + b * x1 + Math.sin(c * x2);
} /**
* 多元多項式資料
* 已知: f(x1,x2) = y = a + b * x1 + c * sin(x2)
* @return
* arr[0] 對應所有的 y 的值
* arr[1] 對應所有的 x1 的值
* arr[2] 對應所有的 x2 的值
*/
public static double[][] multiVarPolyScatters() {
double[][] x = randomX();
double[] y = randomY(x);
OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
ols.newSampleData(y, x);
// ct 即為擬合結果
double[] ct = ols.estimateRegressionParameters(); double[] valueY = new double[x.length];
for (int i = 0; i < x.length; i++) {
// 重新計算 y 的值。與原有構造的 y 對比
valueY[i] = functionValueY(ct, x[i]);
} // 散點資料用於 Echarts 畫圖
double[][] data = new double[x.length][3];// x1, x2, y
for (int i = 0; i < valueY.length; i++) {
// ==================== x1 ====== x2 ======= y ====
data[i] = new double[]{x[i][0], x[i][1], valueY[i]};
}
return data;
}

4.4 畫圖

Echarts 3D畫圖的工具在 https://echarts.apache.org/examples/zh/editor.html?c=line3d-orthographic&gl=1 這個地方。我們將構造資料的函式改為我們的

// ...
var data = [];
// Parametric curve
for (var t = 0; t < 10; t += 0.1) {
// 這裡改成我們的函式。其他的都不變
var x = Math.cos(t);
var y = Math.sin(t);
var z = 20 + 2 * x + 12 * Math.sin(y);
data.push([x, y, z]);
}
// ...

那可以得到這樣一張圖

然後我們執行 org.wfw.chart.data.MultipleLinearRegressionData#main() 方法後將得到的資料整個賦值給 data 覆蓋也行。我們就得到了如下的圖

擬合的結果是 $$ a = 20.01068756847646, b = 2.036022472817587, c = 10.571979017911016 $$ 和我們一開始的確定好的值也差不多

4.5 多說兩句

  • calculateRSquared() 計算 \(R^2\)
  • calculateAdjustedRSquared() 計算 \(ajdRSQ\) ,調整後的 \(R^2\)
  • estimateRegressionParameters() 擬合常數項

關於 newSampleData() 方法引數的 y 和 x 樣本

/**
* Loads model x and y sample data, overriding any previous sample.
*
* Computes and caches QR decomposition of the X matrix.
* @param y the [n,1] array representing the y sample
* @param x the [n,k] array representing the x sample
* @throws MathIllegalArgumentException if the x and y array data are not
* compatible for the regression
*/
public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
validateSampleData(x, y);
newYSampleData(y);
newXSampleData(x);
}

原始碼是這樣的,y 就是 \(f(x_1,x_2)\) 的值,而 x 中的 k 代表的是 \(x_1,x_2\)​ 的值,是順序對應的