Spark-MLlib的快速使用之四(梯度提升樹GBT 分類)


梯度提升樹(GBT)是決策樹的集成(ensemble)。GBT 迭代地訓練決策樹,以使損失函數最小化。spark.mllib(本文範例所用的 RDD API)的實現支援將 GBT 用於二元分類和迴歸,並可同時使用連續特徵和類別特徵。


1 153:5 154:63 155:197 181:20 182:254 183:230 184:24 209:20 210:254 211:254 212:48 237:20 238:254 239:255 240:48 265:20 266:254 267:254 268:57 293:20 294:254 295:254 296:108 321:16 322:239 323:254 324:143 350:178 351:254 352:143 378:178 379:254 380:143 406:178 407:254 408:162 434:178 435:254 436:240 462:113 463:254 464:240 490:83 491:254 492:245 493:31 518:79 519:254 520:246 521:38 547:214 548:254 549:150 575:144 576:241 577:8 603:144 604:240 605:2 631:144 632:254 633:82 659:230 660:247 661:40 687:168 688:209 689:31


public class JavaGradientBoostingClassificationExample {

public static void main(String[] args) {

// $example on$

SparkConf sparkConf = new SparkConf()


JavaSparkContext jsc = new JavaSparkContext(sparkConf);


// Load and parse the data file.

String datapath = "sample_libsvm_data.txt";

JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

// Split the data into training and test sets (30% held out for testing)

JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});

JavaRDD<LabeledPoint> trainingData = splits[0];

JavaRDD<LabeledPoint> testData = splits[1];


// Train a GradientBoostedTrees model.

// The defaultParams for Classification use LogLoss by default.

BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");

boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.



// Empty categoricalFeaturesInfo indicates all features are continuous.

Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();



final GradientBoostedTreesModel model =

GradientBoostedTrees.train(trainingData, boostingStrategy);


// Evaluate model on test instances and compute test error

JavaPairRDD<Double, Double> predictionAndLabel =

testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {


public Tuple2<Double, Double> call(LabeledPoint p) {

return new Tuple2<Double, Double>(model.predict(p.features()), p.label());




Double testErr =

1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {


public Boolean call(Tuple2<Double, Double> pl) {

return !pl._1().equals(pl._2());


}).count() / testData.count();

System.out.println("Test Error: " + testErr);

System.out.println("Learned classification GBT model:\n" + model.toDebugString());


// Save and load model

model.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel");

GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(),


// $example off$



[(0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0), (0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (0.0,0.0)]