程式人生 > Spark-MLlib的快速使用之四(梯度提升樹GBT 分類)

Spark-MLlib的快速使用之四(梯度提升樹GBT 分類)

(1)描述

 梯度提升樹(GBT)是由多棵決策樹組成的整合(ensemble)模型。GBT 以迭代方式逐棵訓練決策樹，使損失函數最小化。Spark MLlib(本例使用以 RDD 為基礎的 spark.mllib API)的實作支援將 GBT 用於二元分類與迴歸，並可同時使用連續特徵與類別特徵。

(2)測試資料(sample_libsvm_data.txt 中的一筆樣本，LIBSVM 格式：「標籤 特徵索引:特徵值 ...」)

1 153:5 154:63 155:197 181:20 182:254 183:230 184:24 209:20 210:254 211:254 212:48 237:20 238:254 239:255 240:48 265:20 266:254 267:254 268:57 293:20 294:254 295:254 296:108 321:16 322:239 323:254 324:143 350:178 351:254 352:143 378:178 379:254 380:143 406:178 407:254 408:162 434:178 435:254 436:240 462:113 463:254 464:240 490:83 491:254 492:245 493:31 518:79 519:254 520:246 521:38 547:214 548:254 549:150 575:144 576:241 577:8 603:144 604:240 605:2 631:144 632:254 633:82 659:230 660:247 661:40 687:168 688:209 689:31

(3)測試程式

public class JavaGradientBoostingClassificationExample {

public static void main(String[] args) {

// $example on$

SparkConf sparkConf = new SparkConf()

.setAppName("JavaGradientBoostedTreesClassificationExample").setMaster("local");

JavaSparkContext jsc = new JavaSparkContext(sparkConf);

 

// Load and parse the data file.

String datapath = "sample_libsvm_data.txt";

JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

// Split the data into training and test sets (30% held out for testing)

JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});

JavaRDD<LabeledPoint> trainingData = splits[0];

JavaRDD<LabeledPoint> testData = splits[1];

 

// Train a GradientBoostedTrees model.

// The defaultParams for Classification use LogLoss by default.

BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");

boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.

boostingStrategy.getTreeStrategy().setNumClasses(2);

boostingStrategy.getTreeStrategy().setMaxDepth(5);

// Empty categoricalFeaturesInfo indicates all features are continuous.

Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();

boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);

 

final GradientBoostedTreesModel model =

GradientBoostedTrees.train(trainingData, boostingStrategy);

 

// Evaluate model on test instances and compute test error

JavaPairRDD<Double, Double> predictionAndLabel =

testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {

@Override

public Tuple2<Double, Double> call(LabeledPoint p) {

return new Tuple2<Double, Double>(model.predict(p.features()), p.label());

}

});

System.out.println(predictionAndLabel.take(10));

Double testErr =

1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {

@Override

public Boolean call(Tuple2<Double, Double> pl) {

return !pl._1().equals(pl._2());

}

}).count() / testData.count();

System.out.println("Test Error: " + testErr);

System.out.println("Learned classification GBT model:\n" + model.toDebugString());

 

// Save and load model

model.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel");

GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(),

"target/tmp/myGradientBoostingClassificationModel");

// $example off$

}

(4)測試結果(以下為前 10 筆「(預測值, 真實標籤)」)

[(0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0), (0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (0.0,0.0)]