Using a Model Trained with Spark MLlib in Java Web
阿新 · Published: 2018-11-15
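PMML (Predictive Model Markup Language) is a general-purpose, XML-based model description format. As long as the exported file follows the standard, you can train a machine learning model in Spark and then use it on the web API side. The most widely used approach today is to load the model with JPMML inside a Java web application, which makes cross-platform machine learning deployment possible.
Training the model
First, train a logistic regression model in Spark MLlib using the RDD-based org.apache.spark.mllib package: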
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint

// `spark` is assumed to be an existing SparkSession (e.g. the one predefined in spark-shell).
// Each string is "label,f0 f1 f2 f3", a text format accepted by LabeledPoint.parse.
val training = spark.sparkContext
  .parallelize(Seq("0,1 2 3 1", "1,2 4 1 5", "0,7 8 3 6", "1,2 5 6 9").map(
    line => LabeledPoint.parse(line)))

// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(training)

val test = spark.sparkContext
  .parallelize(Seq("0,1 2 3 1").map(
    line => LabeledPoint.parse(line)))

// Compute raw scores on the test set.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  val prediction = model.predict(features)
  (prediction, label)
}

// Get evaluation metrics.
val metrics = new MulticlassMetrics(predictionAndLabels)
val accuracy = metrics.accuracy
println(s"Accuracy = $accuracy")

// Save and load model
// model.save(spark.sparkContext, "target/tmp/scalaLogisticRegressionWithLBFGSModel")
// val sameModel = LogisticRegressionModel.load(spark.sparkContext, "target/tmp/scalaLogisticRegressionWithLBFGSModel")

// Export the model as PMML through the SparkContext
model.toPMML(spark.sparkContext, "/tmp/xhl/data/test2")
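Each training string uses the plain "label,f0 f1 f2 f3" text format that LabeledPoint.parse accepts: the class label, a comma, then the feature values separated by spaces. Since the Java side later has to supply the same four values in the same order, here is a minimal stdlib-only Java sketch that splits such a line the same way. LabeledLine is a hypothetical helper, not a Spark API, and it deliberately ignores the "(label,[f0,f1,...])" tuple format that LabeledPoint.parse also understands.

```java
import java.util.Arrays;

// Hypothetical helper: splits the plain "label,f0 f1 f2 f3" format.
// It does not handle the "(label,[f0,f1,...])" tuple format.
final class LabeledLine {
    final double label;
    final double[] features;

    LabeledLine(double label, double[] features) {
        this.label = label;
        this.features = features;
    }

    static LabeledLine parse(String line) {
        // The label comes before the first comma, the features after it
        String[] parts = line.split(",");
        double label = Double.parseDouble(parts[0].trim());
        // Feature values are separated by single spaces
        double[] features = Arrays.stream(parts[1].trim().split(" "))
                .mapToDouble(Double::parseDouble)
                .toArray();
        return new LabeledLine(label, features);
    }

    public static void main(String[] args) {
        LabeledLine p = LabeledLine.parse("1,2 4 1 5");
        System.out.println(p.label + " " + Arrays.toString(p.features)); // prints 1.0 [2.0, 4.0, 1.0, 5.0]
    }
}
```

Feature i of the vector is what the exported PMML file calls field_i, so keeping this order is what lets the web side build its inputs by field name.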
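The trained model is saved to HDFS: toPMML(sc, path) writes the PMML XML through the SparkContext as a text-file directory, which usually contains a single part-00000 file.
PMML model file
Download the model file to the local machine and rename it with an .xml extension.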
As you can see, by default the four features are named field_0 through field_3, and the target is named target:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML version="4.2" xmlns="http://www.dmg.org/PMML-4_2">
    <Header description="logistic regression">
        <Application name="Apache Spark MLlib" version="2.2.0"/>
        <Timestamp>2018-11-15T10:22:25</Timestamp>
    </Header>
    <DataDictionary numberOfFields="5">
        <DataField name="field_0" optype="continuous" dataType="double"/>
        <DataField name="field_1" optype="continuous" dataType="double"/>
        <DataField name="field_2" optype="continuous" dataType="double"/>
        <DataField name="field_3" optype="continuous" dataType="double"/>
        <DataField name="target" optype="categorical" dataType="string"/>
    </DataDictionary>
    <RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit">
        <MiningSchema>
            <MiningField name="field_0" usageType="active"/>
            <MiningField name="field_1" usageType="active"/>
            <MiningField name="field_2" usageType="active"/>
            <MiningField name="field_3" usageType="active"/>
            <MiningField name="target" usageType="target"/>
        </MiningSchema>
        <RegressionTable intercept="0.0" targetCategory="1">
            <NumericPredictor name="field_0" coefficient="-5.552297758753701"/>
            <NumericPredictor name="field_1" coefficient="-1.4863480719075117"/>
            <NumericPredictor name="field_2" coefficient="-5.7232298850417855"/>
            <NumericPredictor name="field_3" coefficient="8.134075057437393"/>
        </RegressionTable>
        <RegressionTable intercept="-0.0" targetCategory="0"/>
    </RegressionModel>
</PMML>
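The RegressionModel above is a linear score passed through the logistic function: with normalizationMethod="logit", the probability of category 1 is 1 / (1 + exp(-(intercept + sum of coefficient * field))), and category 0, whose RegressionTable has no predictors, gets the remaining probability. The intercept is 0.0, presumably because LogisticRegressionWithLBFGS does not fit an intercept unless asked to. As a sanity check, this stdlib-only Java sketch hard-codes the four coefficients copied from the file. LogitScorer and probabilityOfOne are illustrative names, not JPMML code, and the sketch should match the probabilities JPMML prints further down up to floating-point rounding.

```java
// Illustrative re-implementation of the exported RegressionModel; not JPMML code
final class LogitScorer {
    // Coefficients copied from <RegressionTable targetCategory="1"> in the PMML file
    static final double[] COEFFICIENTS = {
        -5.552297758753701, -1.4863480719075117, -5.7232298850417855, 8.134075057437393
    };
    static final double INTERCEPT = 0.0;

    // Probability of targetCategory "1" given field_0..field_3
    static double probabilityOfOne(double... fields) {
        double y = INTERCEPT;
        for (int i = 0; i < COEFFICIENTS.length; i++) {
            y += COEFFICIENTS[i] * fields[i];
        }
        // normalizationMethod="logit": apply the logistic function to the linear score
        return 1.0 / (1.0 + Math.exp(-y));
    }

    public static void main(String[] args) {
        double p1 = probabilityOfOne(2, 5, 6, 8);
        // Category "0" has an empty RegressionTable and takes the remaining probability
        System.out.println("1=" + p1 + ", 0=" + (1.0 - p1));
    }
}
```

Doing this by hand is only useful for checking the exported numbers; in the web service itself the JPMML evaluator does this work and also handles field preparation and validation.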
Using the model in the API
Add the following Maven dependencies to the web project that serves the API:
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.3</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator-extension -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.3</version>
</dependency>
In the API code, read the PMML file directly and use the model to make predictions:
package soundsystem;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class PMMLDemo2 {

    // Load the PMML file exported by Spark and build an evaluator for it
    private Evaluator loadPmml() {
        PMML pmml;
        try (InputStream inputStream = new FileInputStream("/Users/xingoo/Desktop/test2.xml")) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (Exception e) {
            // Fail fast instead of continuing with an empty PMML object
            throw new IllegalStateException("Failed to load the PMML model", e);
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    private Object predict(Evaluator evaluator, int a, int b, int c, int d) {
        // Raw feature values, keyed by the field names used in the PMML file
        Map<String, Integer> data = new HashMap<String, Integer>();
        data.put("field_0", a);
        data.put("field_1", b);
        data.put("field_2", c);
        data.put("field_3", d);

        List<InputField> inputFields = evaluator.getInputFields();
        // Iterate over the model's raw input features, look up each value
        // (in a real service, e.g. from a user profile) and turn it into a model input
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            // prepare() converts and validates the raw value against the field definition
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        // This model has a single target field ("target")
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();
        ProbabilityDistribution target = (ProbabilityDistribution) results.get(targetFieldName);
        System.out.println(a + " " + b + " " + c + " " + d + ":" + target);
        return target;
    }

    public static void main(String[] args) {
        PMMLDemo2 demo = new PMMLDemo2();
        Evaluator model = demo.loadPmml();
        demo.predict(model, 2, 5, 6, 8);
        demo.predict(model, 7, 9, 3, 6);
        demo.predict(model, 1, 2, 3, 1);
        demo.predict(model, 2, 4, 1, 5);
    }
}
The output is:
2 5 6 8:ProbabilityDistribution{result=1, probability_entries=[1=0.9999949538769296, 0=5.046123070395758E-6]}
7 9 3 6:ProbabilityDistribution{result=0, probability_entries=[1=1.1216598160542013E-9, 0=0.9999999988783402]}
1 2 3 1:ProbabilityDistribution{result=0, probability_entries=[1=2.363331367481431E-8, 0=0.9999999763666864]}
2 4 1 5:ProbabilityDistribution{result=1, probability_entries=[1=0.9999999831203591, 0=1.6879640907241367E-8]}
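Here, result is the final prediction of the logistic regression (the predicted class), and probability_entries holds the probabilities of the two classes.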
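In all four outputs above, result is the category with the higher probability, so for this two-class model it amounts to predicting 1 when the probability of 1 exceeds 0.5. The stdlib-only Java sketch below reproduces that decision over a category-to-probability map shaped like probability_entries. ArgMax and mostProbable are hypothetical names, and the tie-breaking rule (the entry seen first wins) is an assumption of this sketch, not documented JPMML behaviour.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper: pick the most probable category from a probability map
final class ArgMax {
    static String mostProbable(Map<String, Double> probabilities) {
        String best = null;
        double bestProbability = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, Double> entry : probabilities.entrySet()) {
            // Strictly greater: on a tie the entry seen first is kept
            if (entry.getValue() > bestProbability) {
                best = entry.getKey();
                bestProbability = entry.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Values copied from the "7 9 3 6" output line above
        Map<String, Double> probabilities = new LinkedHashMap<>();
        probabilities.put("1", 1.1216598160542013E-9);
        probabilities.put("0", 0.9999999988783402);
        System.out.println(mostProbable(probabilities)); // prints 0
    }
}
```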
References
- Official documentation: https://openscoring.io/
- JPMML official documentation: https://github.com/jpmml/jpmml-evaluator
- jpmml-sklearn: https://github.com/jpmml/jpmml-sklearn
- jpmml-sparkml: https://github.com/jpmml/jpmml-sparkml/tree/master
- Deploying machine learning models across platforms with PMML: http://www.cnblogs.com/pinard/p/9220199.html
- Practical experience with PMML model files in machine learning: https://blog.csdn.net/hopeztm/article/details/78321700