
A Learning Summary of Spark MLlib (Java Edition)

This post walks through how to build a machine learning model with Spark's MLlib and use it to predict on new data. The overall workflow is shown below:
Basic workflow

Loading the data

For loading and saving data, MLlib provides the MLUtils class, described in the API docs as "Helper methods to load, save and pre-process data used in MLLib." This post uses the sample_libsvm_data.txt dataset shipped with Spark, which has 100 samples and 692 features. The data looks like this:
Data format

Loading the libsvm file

JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();

The LabeledPoint type corresponds to the libsvm file format; each record consists of:
label (double), vector (Vector)
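As a standalone illustration (plain Java, no Spark), a libsvm line of the form `<label> <index>:<value> ...` maps onto the label/vector pair a LabeledPoint holds. The parser below is a hypothetical sketch of that mapping, not MLUtils' actual implementation, and the sample line is made up:

```java
public class LibsvmLineDemo {
    // The label is the first whitespace-separated token of the line
    static double parseLabel(String line) {
        return Double.parseDouble(line.split(" ")[0]);
    }

    // The remaining "index:value" pairs use 1-based indices in the file;
    // convert them to the 0-based indices a sparse vector uses
    static int[] parseIndices(String line) {
        String[] parts = line.split(" ");
        int[] indices = new int[parts.length - 1];
        for (int i = 1; i < parts.length; i++) {
            indices[i - 1] = Integer.parseInt(parts[i].split(":")[0]) - 1;
        }
        return indices;
    }

    public static void main(String[] args) {
        String line = "1 128:51 129:159 130:253"; // made-up libsvm line
        System.out.println(parseLabel(line));      // 1.0
        System.out.println(parseIndices(line)[0]); // 127
    }
}
```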

Converting to the DataFrame type

JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);

DataFrame: a DataFrame is a distributed collection of data organized into named columns. Conceptually it is equivalent to a table in a relational database or a data frame in Python or R, but with richer optimizations under the hood. A DataFrame can be constructed from structured data files, Hive tables, external databases, or existing RDDs.

SQLContext: the entry point for all Spark SQL functionality is the SQLContext class, or one of its subclasses. Creating a basic SQLContext requires a SparkContext.

Feature extraction

Feature normalization

StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);
StandardScalerModel scalerModel = scaler.fit(df);
DataFrame scalerDF = scalerModel.transform(df);
//Save the fitted model (not the unfitted estimator) so the same scaling can be reapplied at prediction time
scalerModel.save(this.scalerModelPath);
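For intuition, here is a minimal plain-Java sketch of what this scaler computes: with setWithStd(true) and the default withMean(false), each feature column is divided by its sample standard deviation, without mean-centering. The column values below are made up:

```java
public class ScaleDemo {
    // Sample (n-1) standard deviation of one feature column,
    // matching the column-wise statistics StandardScaler collects
    static double std(double[] col) {
        double mean = 0;
        for (double v : col) mean += v;
        mean /= col.length;
        double var = 0;
        for (double v : col) var += (v - mean) * (v - mean);
        return Math.sqrt(var / (col.length - 1));
    }

    // withStd=true, withMean=false: divide by std, do not subtract the mean
    static double[] scale(double[] col) {
        double s = std(col);
        double[] out = new double[col.length];
        for (int i = 0; i < col.length; i++) out[i] = col[i] / s;
        return out;
    }

    public static void main(String[] args) {
        double[] col = {2.0, 4.0, 6.0};      // made-up feature column
        double[] scaled = scale(col);
        System.out.println(std(col));        // 2.0
        System.out.println(scaled[2]);       // 3.0
    }
}
```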

Feature selection with the chi-squared statistic

ChiSqSelector selector = new ChiSqSelector()
        .setNumTopFeatures(500)
        .setFeaturesCol("normFeatures")
        .setLabelCol("label")
        .setOutputCol("selectedFeatures");
ChiSqSelectorModel chiModel = selector.fit(scalerDF);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
chiModel.save(this.featureSelectedModelPath);
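ChiSqSelector ranks each feature by a chi-squared test of independence against the label and keeps the top-scoring ones (here, the top 500). The plain-Java sketch below shows the statistic on a single 2x2 contingency table of (feature present/absent) x (label 0/1); the counts are made up:

```java
public class ChiSqDemo {
    // chi2 = sum over cells of (observed - expected)^2 / expected,
    // where expected = rowTotal * colTotal / grandTotal
    static double chi2(long[][] obs) {
        long total = 0;
        long[] rowSum = new long[obs.length];
        long[] colSum = new long[obs[0].length];
        for (int r = 0; r < obs.length; r++)
            for (int c = 0; c < obs[0].length; c++) {
                rowSum[r] += obs[r][c];
                colSum[c] += obs[r][c];
                total += obs[r][c];
            }
        double stat = 0;
        for (int r = 0; r < obs.length; r++)
            for (int c = 0; c < obs[0].length; c++) {
                double expected = (double) rowSum[r] * colSum[c] / total;
                double d = obs[r][c] - expected;
                stat += d * d / expected;
            }
        return stat;
    }

    public static void main(String[] args) {
        // Feature perfectly aligned with the label -> large chi2 (kept);
        // feature independent of the label -> chi2 of 0 (dropped)
        System.out.println(chi2(new long[][]{{10, 0}, {0, 10}})); // 20.0
        System.out.println(chi2(new long[][]{{5, 5}, {5, 5}}));   // 0.0
    }
}
```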

Training a machine learning model (using SVM as an example)

//Convert back to the LabeledPoint type to build the training set
JavaRDD<Row> selectedrows = selectedDF.javaRDD();
JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());

//Train the SVM model and save it
int numIteration = 200;
SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
model.clearThreshold();
model.save(sc, this.mlModelPath);
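After clearThreshold(), predict returns the raw margin w·x + b rather than a 0/1 class, which is why the scores compared later can be negative. A plain-Java sketch of that scoring, with made-up weights:

```java
public class SvmScoreDemo {
    // Raw SVM score: dot product of weights and features, plus the intercept
    static double rawScore(double[] w, double[] x, double b) {
        double dot = 0;
        for (int i = 0; i < w.length; i++) dot += w[i] * x[i];
        return dot + b;
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0};                       // made-up weights
        double score = rawScore(w, new double[]{2.0, 1.0}, 0.25);
        System.out.println(score);                      // 0.25
        // With a threshold of 0.0, this score would be classified as 1.0
        System.out.println(score >= 0.0 ? 1.0 : 0.0);   // 1.0
    }
}
```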

// Convert the LabeledPoint type to Row
static class LabeledPointToRow implements Function<LabeledPoint, Row> {

        public Row call(LabeledPoint p) throws Exception {
            double label = p.label();
            Vector vector = p.features();
            return RowFactory.create(label, vector);
        }
    }

//Convert the Row type back to LabeledPoint
static class RowToLabel implements Function<Row, LabeledPoint> {

        public LabeledPoint call(Row r) throws Exception {
            Vector features = r.getAs(1);
            double label = r.getDouble(0);
            return new LabeledPoint(label, features);
        }
    }

Predicting on new samples

Before predicting on new samples, the same data conversion and feature extraction steps must be applied to them. This is why, during training, we saved not only the machine learning model but also the intermediate feature extraction models. The code is as follows:

//Initialize Spark
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
conf.set("spark.testing.memory", "2147480000");
SparkContext sc = new SparkContext(conf);

//Load the test data
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();

//Convert to the DataFrame type
JavaRDD<Row> jrow = testData.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);

//Normalize with the scaler fitted at training time (do not refit on test data)
StandardScalerModel scalerModel = StandardScalerModel.load(this.scalerModelPath);
DataFrame scalerDF = scalerModel.transform(df);

//Feature selection
ChiSqSelectorModel chiModel = ChiSqSelectorModel.load(this.featureSelectedModelPath);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");

Scoring the test set

SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
JavaRDD<LabeledPoint> testset = selectedDF.javaRDD().map(new RowToLabel());
JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel));
predictResult.collect();

static class Prediction implements Function<LabeledPoint, Tuple2<Double, Double>> {
        SVMModel model;
        public Prediction(SVMModel model){
            this.model = model;
        }
        public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
            Double score = model.predict(p.features());
            return new Tuple2<Double, Double>(score, p.label());
        }
    }

Computing the accuracy

double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
System.out.println(accuracy);

static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
        public Boolean call(Tuple2<Double, Double> t) throws Exception {
            double score = t._1();
            double label = t._2();
            System.out.println("score:" + score + ", label:" + label);
            // SVMWithSGD uses 0/1 labels; a raw score >= 0.0 predicts class 1
            return (score >= 0.0) == (label > 0.0);
        }
    }
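The accuracy is simply the fraction of (score, label) pairs where the sign of the raw score agrees with the class. A plain-Java sketch of that computation, assuming 0/1 labels as SVMWithSGD uses and made-up score/label pairs:

```java
public class AccuracyDemo {
    // A pair is correct when a non-negative score meets a positive label,
    // or a negative score meets the 0 label
    static boolean correct(double score, double label) {
        return (score >= 0.0) == (label > 0.0);
    }

    // Fraction of correct pairs, mirroring the filter/count ratio above
    static double accuracy(double[][] pairs) {
        int hits = 0;
        for (double[] p : pairs)
            if (correct(p[0], p[1])) hits++;
        return hits * 1.0 / pairs.length;
    }

    public static void main(String[] args) {
        // {rawScore, label} pairs: three correct, one wrong
        double[][] pairs = {{1.2, 1.0}, {-0.3, 0.0}, {0.8, 0.0}, {-2.0, 0.0}};
        System.out.println(accuracy(pairs)); // 0.75
    }
}
```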