Spark MLlib Quick Start, Part 6: Regression Analysis (Logistic Regression)

(1) Algorithm Description

Logistic regression (Logistic Regression) is a regression technique for problems where the dependent variable is categorical. It is most commonly applied to binary (binomial) classification, but it can also handle multiclass problems. In practice it is really a classification method.
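At its core, logistic regression passes a linear score through the sigmoid function to get a class probability. The sketch below illustrates the idea for the binary case; the weights and intercept are made-up illustrative values, not parameters trained on the data in this post.

```java
public class SigmoidSketch {
    // Logistic (sigmoid) function: maps any real-valued score into (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Binary logistic regression prediction: P(y = 1 | x) for weights w,
    // intercept b, and feature vector x. All values here are hypothetical.
    static double predictProbability(double[] w, double b, double[] x) {
        double z = b;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        double[] w = {1.5, -2.0};          // illustrative weights
        double[] x = {0.5, -0.762712};     // one feature vector
        double p = predictProbability(w, 0.1, x);
        System.out.println("P(y=1|x) = " + p);
        System.out.println("predicted class = " + (p >= 0.5 ? 1 : 0));
    }
}
```

Multiclass logistic regression (as used later with `setNumClasses(3)`) generalizes this by learning one set of weights per class and picking the class with the highest score.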

(2) Test Data

1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333
1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667
0 1:0.166667 2:-0.416667 3:0.457627 4:0.5
1 1:-0.833333 3:-0.864407 4:-0.916667
2 1:-1.32455e-07 2:-0.166667 3:0.220339 4:0.0833333
2 1:-1.32455e-07 2:-0.333333 3:0.0169491 4:-4.03573e-08
1 1:-0.5 2:0.75 3:-0.830508 4:-1
0 1:0.611111 3:0.694915 4:0.416667
0 1:0.222222 2:-0.166667 3:0.423729 4:0.583333
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-1
1 1:-0.5 2:0.166667 3:-0.864407 4:-0.916667
2 1:-0.222222 2:-0.333333 3:0.0508474 4:-4.03573e-08
2 1:-0.0555556 2:-0.833333 3:0.0169491 4:-0.25
2 1:-0.166667 2:-0.416667 3:-0.0169491 4:-0.0833333
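The data above is in LIBSVM format, which is what `MLUtils.loadLibSVMFile` reads: each line is a label followed by sparse `index:value` pairs with 1-based indices, where omitted indices are implicitly zero. A minimal sketch of how one such line maps to a label and a dense feature array (the parser below is an illustration, not the actual MLlib implementation):

```java
import java.util.Arrays;

public class LibsvmLineParser {
    // LIBSVM format: "label index:value index:value ..." with 1-based,
    // increasing feature indices; omitted indices are implicitly 0.
    static double parseLabel(String line) {
        return Double.parseDouble(line.trim().split("\\s+")[0]);
    }

    static double[] parseFeatures(String line, int numFeatures) {
        String[] parts = line.trim().split("\\s+");
        double[] features = new double[numFeatures];
        for (int i = 1; i < parts.length; i++) {
            String[] kv = parts[i].split(":");
            features[Integer.parseInt(kv[0]) - 1] = Double.parseDouble(kv[1]);
        }
        return features;
    }

    public static void main(String[] args) {
        // The sixth data row above omits feature 2, so it defaults to 0.0.
        String line = "1 1:-0.833333 3:-0.864407 4:-0.916667";
        System.out.println("label=" + parseLabel(line)
            + " features=" + Arrays.toString(parseFeatures(line, 4)));
    }
}
```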

(3) Test Code

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

public class JavaMulticlassClassificationMetricsExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics Example");
    SparkContext sc = new SparkContext(conf);
    // $example on$
    String path = "sample_multiclass_classification_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

    // Split the initial RDD into two: 60% training data, 40% test data.
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];

    // Run the training algorithm to build the model.
    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
      .setNumClasses(3)
      .run(training.rdd());

    // Compute raw scores on the test set.
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
      new Function<LabeledPoint, Tuple2<Object, Object>>() {
        public Tuple2<Object, Object> call(LabeledPoint p) {
          Double prediction = model.predict(p.features());
          return new Tuple2<Object, Object>(prediction, p.label());
        }
      }
    );
    System.out.println("--------------------->" + predictionAndLabels.take(10));

    // Get evaluation metrics.
    MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());

    // Confusion matrix
    Matrix confusion = metrics.confusionMatrix();
    System.out.println("Confusion matrix: \n" + confusion);

    // Overall statistics
    System.out.println("Precision = " + metrics.precision());
    System.out.println("Recall = " + metrics.recall());
    System.out.println("F1 Score = " + metrics.fMeasure());

    // Statistics by class label
    for (int i = 0; i < metrics.labels().length; i++) {
      System.out.format("Class %f precision = %f\n", metrics.labels()[i],
        metrics.precision(metrics.labels()[i]));
      System.out.format("Class %f recall = %f\n", metrics.labels()[i],
        metrics.recall(metrics.labels()[i]));
      System.out.format("Class %f F1 score = %f\n", metrics.labels()[i],
        metrics.fMeasure(metrics.labels()[i]));
    }

    // Weighted statistics
    System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
    System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
    System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
    System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());

    // Save and load the model
    model.save(sc, "target/tmp/LogisticRegressionModel");
    LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
      "target/tmp/LogisticRegressionModel");
    // $example off$

    sc.stop();
  }
}
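To see what `MulticlassMetrics` is computing under the hood: with a confusion matrix whose rows are actual labels and columns are predicted labels, each class's precision is its diagonal count divided by its column sum, and its recall is the diagonal count divided by its row sum. A self-contained sketch (the 3x3 counts below are illustrative, not the actual output of the run above):

```java
public class ConfusionMatrixStats {
    // Per-class precision/recall/F1 from a confusion matrix m, where
    // m[actual][predicted] counts test examples (rows = actual labels,
    // columns = predicted labels). Returns {precision, recall, f1}.
    static double[] stats(long[][] m, int c) {
        long tp = m[c][c];
        long colSum = 0, rowSum = 0;
        for (int i = 0; i < m.length; i++) {
            colSum += m[i][c];  // everything predicted as class c
            rowSum += m[c][i];  // everything actually class c
        }
        double precision = (double) tp / colSum;
        double recall = (double) tp / rowSum;
        double f1 = 2 * precision * recall / (precision + recall);
        return new double[]{precision, recall, f1};
    }

    public static void main(String[] args) {
        // Illustrative counts for a 3-class problem.
        long[][] m = {{10, 1, 0}, {0, 12, 2}, {1, 0, 9}};
        for (int c = 0; c < m.length; c++) {
            double[] s = stats(m, c);
            System.out.printf("Class %d precision=%.3f recall=%.3f F1=%.3f%n",
                c, s[0], s[1], s[2]);
        }
    }
}
```

The weighted metrics printed by the Spark example are just these per-class values averaged with each class's share of the test set as its weight.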

(4) Test Results

 

The first ten (prediction, label) pairs on the test set:

>[(1.0,1.0), (1.0,1.0), (0.0,0.0), (0.0,0.0), (1.0,1.0), (1.0,1.0), (2.0,2.0), (1.0,1.0), (2.0,2.0), (0.0,0.0)]