程式人生 > Spark-MLlib的快速使用之五(梯度提升樹GBT 迴歸)

Spark-MLlib的快速使用之五(梯度提升樹GBT 迴歸)

(1)描述

 梯度提升樹(GBT)是由多棵決策樹組成的集成(ensemble)模型。GBT 以迭代方式逐棵訓練決策樹,使損失函數最小化。Spark MLlib(包括本例使用的 RDD 版 spark.mllib API)支援將 GBT 用於二元分類與迴歸,並可同時處理連續特徵與類別特徵。

(2)測試資料(LIBSVM 格式:「標籤 特徵索引:特徵值 ...」,以下為 sample_libsvm_data.txt 中的一行樣例)

1 153:5 154:63 155:197 181:20 182:254 183:230 184:24 209:20 210:254 211:254 212:48 237:20 238:254 239:255 240:48 265:20 266:254 267:254 268:57 293:20 294:254 295:254 296:108 321:16 322:239 323:254 324:143 350:178 351:254 352:143 378:178 379:254 380:143 406:178 407:254 408:162 434:178 435:254 436:240 462:113 463:254 464:240 490:83 491:254 492:245 493:31 518:79 519:254 520:246 521:38 547:214 548:254 549:150 575:144 576:241 577:8 603:144 604:240 605:2 631:144 632:254 633:82 659:230 660:247 661:40 687:168 688:209 689:31

(3)測試程式

// $example on$
// Run Spark locally with a single worker thread.
SparkConf sparkConf = new SparkConf()
    .setAppName("JavaGradientBoostedTreesRegressionExample").setMaster("local");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file (LIBSVM format: "label index:value index:value ...").
String datapath = "sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

// Split the data into training and test sets (30% held out for testing).
// NOTE(review): randomSplit is unseeded, so the split and the reported MSE vary between runs.
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Train a GradientBoostedTrees model.
// The defaultParams for Regression use SquaredError by default.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
boostingStrategy.getTreeStrategy().setMaxDepth(5);
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
boostingStrategy.getTreeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);

// 'final' so the anonymous PairFunction below can capture it.
final GradientBoostedTreesModel model = GradientBoostedTrees.train(trainingData, boostingStrategy);

// Evaluate model on test instances: pair each prediction with its true label.
JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint p) {
        return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
      }
    });
System.out.println(predictionAndLabel.take(10));

// Sum of squared errors over the test set. fold(0.0, ...) is used instead of reduce(...)
// because reduce throws on an empty RDD, which can happen if the random split leaves no test points.
double sumSquaredError =
    predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
      @Override
      public Double call(Tuple2<Double, Double> pl) {
        double diff = pl._1() - pl._2();
        return diff * diff;
      }
    }).fold(0.0, new Function2<Double, Double, Double>() {
      @Override
      public Double call(Double a, Double b) {
        return a + b;
      }
    });
// Mean over the TEST points only; dividing by data.count() (all points) would understate the error.
// An empty test set gives 0.0 / 0 = NaN rather than an exception.
double testMSE = sumSquaredError / testData.count();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression GBT model:\n" + model.toDebugString());

// Save and load model (round-trip demonstration; sameModel is not used further).
model.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel");
GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(),
    "target/tmp/myGradientBoostingRegressionModel");
// $example off$

// Release the Spark context's resources once the example is done.
jsc.stop();

}

(4)測試結果(以下為前 10 筆 (預測值, 實際標籤) 配對)

[(0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0), (0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (0.0,0.0)]