
Spark 2.x Decision Tree Example Code: IRIS Dataset

Dataset download

Download link

Code

package Iris;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import util.InitSparkUtil;

import java.util.HashMap;
import java.util.Map;

/**
 * Created by xy on 2018/4/20.
 */
public class IrisDT {

    public static final String[] iris = new String[]{"Iris_setosa", "Iris_versicolor", "Iris_virginica"};

    public static void irisDT() {
        // 1. Build the SparkSession
        InitSparkUtil initSparkUtil = new InitSparkUtil();
        SparkSession spark = initSparkUtil.getSparkSession("irisDT");

        // 2. Load the data
        Dataset<Row> data = spark.read().csv("E:\\idea工程\\data\\iris.csv");
        data = data.toDF("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species");
        JavaRDD<String> dataRdd = data.toJavaRDD().map(x -> x.toString().replace("[", "").replace("]", ""));

        // 3. Convert each record into a Row of (feature vector, label)
        JavaRDD<Row> irisRowRDD = dataRdd.map(x -> x.split(",")).map(x -> {
            double[] ds = new double[x.length - 1];
            for (int i = 0; i < x.length - 1; i++) {
                ds[i] = Double.parseDouble(x[i]);
            }
            return RowFactory.create(Vectors.dense(ds), x[x.length - 1].replace("-", "_"));
        });

        // 4. Define the StructType (schema)
        StructType schema = new StructType(new StructField[]{
                new StructField("features", new VectorUDT(), false, Metadata.empty()),
                new StructField("label", DataTypes.StringType, false, Metadata.empty())
        });

        // 5. Stratified sampling: 80% of each class for training, the rest for testing
        JavaRDD<Row> trainDataRDD = stratifiedSample(irisRowRDD);
        JavaRDD<Row> testDataRDD = irisRowRDD.subtract(trainDataRDD);
        Dataset<Row> trainData = spark.createDataFrame(trainDataRDD, schema);
        Dataset<Row> testData = spark.createDataFrame(testDataRDD, schema);
        Dataset<Row> fullData = trainData.union(testData);
        fullData.cache();
        trainData.show(150);
        testData.show(150);
        fullData.show(2000);

        /*
         * 6. Every fit call produces a Model. VectorIndexer indexes the feature columns:
         * a column with at most maxCategories (here 4) distinct values is treated as a
         * categorical (integer-indexed) feature; otherwise it is kept as a continuous value.
         * The input column name must match the one defined in the StructType.
         */
        VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setMaxCategories(4)
                .setOutputCol("indexedFeatures")
                .fit(fullData);
        Dataset<Row> featureIndexData = featureIndexer.transform(fullData);
        featureIndexData.show(200);

        /*
         * 7. StringIndexer converts the class column from String to numeric labels
         * (integer-valued, starting from 0) so the tree can work with it. Indices are
         * assigned by frequency: the most frequent class gets index 0.
         */
        StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(fullData);
        Dataset<Row> labelIndexData = labelIndexer.transform(fullData);
        labelIndexData.show(200);

        // 8. Convert the predicted label indices back to Strings
        IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());

        /*
         * 9. Build the decision tree.
         * setMaxDepth: maximum tree depth; setMinInfoGain: minimum information gain for a split;
         * setMinInstancesPerNode: a node with fewer samples than this is not split further;
         * setImpurity: impurity measure, "gini" for Gini impurity or "entropy" for information entropy.
         */
        DecisionTreeClassifier dtClassifier = new DecisionTreeClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures")
                .setMaxDepth(20)
                .setMinInfoGain(0.00001)
                .setMinInstancesPerNode(1)
                .setImpurity("gini");

        // Build the Pipeline
        Pipeline pipeline = new Pipeline().setStages(
                new PipelineStage[]{labelIndexer, featureIndexer, dtClassifier, labelConverter});

        /*
         * A Pipeline has two key methods:
         * fit: takes a DataFrame, trains on it, and produces a model, i.e. it learns
         * statistical regularities from the data and returns a fitted model.
         * transform: maps one DataFrame to another, applying operations such as
         * feature transformation or prediction to the data.
         */
        // Train
        PipelineModel modelClassifier = pipeline.fit(trainData);
        // Predict
        Dataset<Row> predictionClassifier = modelClassifier.transform(testData);
        predictionClassifier.select("predictedLabel", "label", "features").show(200);

        // Evaluate
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictionClassifier);
        System.out.println(accuracy);

        // Print the learned tree structure
        Transformer dtModel = modelClassifier.stages()[2];
        DecisionTreeClassificationModel treeClassModel = (DecisionTreeClassificationModel) dtModel;
        String treeModelStruct = treeClassModel.toDebugString();
        System.out.println(treeModelStruct);

        fullData.unpersist();
    }

    protected static JavaRDD<Row> stratifiedSample(JavaRDD<Row> irisRowRDD) {
        // Key each Row by its label (column 1), then sample 80% of each class
        JavaPairRDD<String, Row> pairRDD = irisRowRDD.mapToPair(x -> new Tuple2<>(x.getString(1), x));
        Map<String, Double> fractions = new HashMap<>();
        for (int i = 0; i < iris.length; i++) {
            fractions.put(iris[i], 0.8);
        }
        JavaRDD<Row> trainRDD = pairRDD.sampleByKeyExact(false, fractions, 0).map(x -> x._2);
        return trainRDD;
    }

    public static void main(String[] args) {
        irisDT();
    }
}
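As an aside, the manual string-splitting in steps 2 and 3 can be avoided: Spark can infer numeric column types when reading the CSV, and VectorAssembler can build the feature vector directly. Below is a minimal sketch of that alternative; the IrisAssemblerSketch class and its loadIris helper are illustrative names, not part of the original code, and it assumes the same header-less iris.csv as above.

// Alternative sketch (hypothetical helper): build the "features"/"label" DataFrame
// with VectorAssembler instead of splitting Row strings by hand.
package Iris;

import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.regexp_replace;

public class IrisAssemblerSketch {

    public static Dataset<Row> loadIris(SparkSession spark, String path) {
        // inferSchema makes the four measurement columns doubles instead of strings
        Dataset<Row> raw = spark.read()
                .option("inferSchema", "true")
                .csv(path)
                .toDF("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species");

        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{"Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width"})
                .setOutputCol("features");

        // Normalize "Iris-setosa" style labels to the underscore form used above,
        // and keep only the two columns the pipeline expects.
        return assembler.transform(raw)
                .withColumn("label", regexp_replace(col("Species"), "-", "_"))
                .select("features", "label");
    }
}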
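The stratifiedSample helper keys every Row by its label and calls sampleByKeyExact which, unlike the approximate sampleByKey, makes extra passes over the RDD to return exactly ceil(numItems * fraction) records per key (here 80% of each species). A tiny self-contained illustration of the call, with made-up keys:

// Hypothetical mini-example of sampleByKeyExact: keep exactly half of each key's records.
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SampleByKeySketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new util.InitSparkUtil().getSc("sampleByKeySketch");

        List<Tuple2<String, Integer>> rows = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            rows.add(new Tuple2<>(i % 2 == 0 ? "a" : "b", i));  // 5 records per key
        }
        JavaPairRDD<String, Integer> pairs = sc.parallelizePairs(rows);

        Map<String, Double> fractions = new HashMap<>();
        fractions.put("a", 0.5);
        fractions.put("b", 0.5);

        // false = sample without replacement; 0 = seed, matching the article's call.
        // Expect exactly Math.ceil(5 * 0.5) = 3 records per key.
        pairs.sampleByKeyExact(false, fractions, 0)
             .collect()
             .forEach(System.out::println);

        sc.stop();
    }
}

The InitSparkUtil helper class used throughout is listed below.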
package util;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

/**
 * Utility class for initializing Spark.
 */
public class InitSparkUtil {

    private JavaSparkContext sc;

    public SparkSession getSparkSession(String appname) {
        SparkConf conf = new SparkConf().setMaster("local");
        SparkSession spark = SparkSession.builder().appName(appname).config(conf).getOrCreate();
        return spark;
    }

    public JavaSparkContext getSc(String appname) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName(appname);
        this.sc = new JavaSparkContext(conf);
        return sc;
    }

}
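Both factory methods hard-code a local master, which is convenient for running inside an IDE; on a real cluster the master would normally be supplied via spark-submit rather than in code. A quick smoke-test usage sketch (the app name is arbitrary):

// Hypothetical smoke test for InitSparkUtil
SparkSession spark = new InitSparkUtil().getSparkSession("smokeTest");
spark.range(5).show();  // prints a 5-row DataFrame of ids 0-4
spark.stop();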