程式人生 > Spark2.3.2 機器學習工作流構建

Spark2.3.2 機器學習工作流構建

scala> import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSession

scala> val spark = SparkSession.builder().
     |             master("local").
     |             appName("my App Name").
     |             getOrCreate()
2018-12-07 02:14:10 WARN  SparkSession$Builder:66 - Using an existing SparkSession; some configuration may not take effect.
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@<hash>
scala> import org.apache.spark.ml.feature._
import org.apache.spark.ml.feature._

scala> import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.classification.LogisticRegression

scala> import org.apache.spark.ml.{Pipeline,PipelineModel}
import org.apache.spark.ml.{Pipeline, PipelineModel}

scala> import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.Vector

scala> import org.apache.spark.sql.Row
import org.apache.spark.sql.Row

scala> val training = spark.createDataFrame(Seq((0L, "a b c d e spark", 1.0),(1L, "b d", 0.0),(2L, "spark f g h", 1.0),(3L, "hadoop mapreduce", 0.0))).toDF("id", "text", "label")
2018-12-07 02:15:29 WARN  ObjectStore:568 - Failed to get database global_temp, returning NoSuchObjectException
training: org.apache.spark.sql.DataFrame = [id: bigint, text: string ... 1 more field]

scala> val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
tokenizer: org.apache.spark.ml.feature.Tokenizer = tok_b90cb26b1f51

scala> val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol(tokenizer.getOutputCol).setOutputCol("features")
hashingTF: org.apache.spark.ml.feature.HashingTF = hashingTF_e810c12ed27c

scala> val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_fdee17135e3d

scala> val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
pipeline: org.apache.spark.ml.Pipeline = pipeline_a9b6a2d92374

scala> val model = pipeline.fit(training)
2018-12-07 02:16:55 WARN  BLAS:61 - Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
2018-12-07 02:16:55 WARN  BLAS:61 - Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
model: org.apache.spark.ml.PipelineModel = pipeline_a9b6a2d92374

scala> val test = spark.createDataFrame(Seq((4L, "spark i j k"),(5L, "l m n"),(6L, "spark a"),(7L, "apache hadoop"))).toDF("id", "text")
test: org.apache.spark.sql.DataFrame = [id: bigint, text: string]

scala> model.transform(test).select("id", "text", "probability", "prediction").collect().foreach {case Row(id: Long, text: String, prob: Vector, prediction: Double) => println(s"($id, $text) --> prob=$prob, prediction=$prediction")}
(4, spark i j k) --> prob=[0.540643354485232,0.45935664551476796], prediction=0.0
(5, l m n) --> prob=[0.9334382627383527,0.06656173726164716], prediction=0.0
(6, spark a) --> prob=[0.1504143004807332,0.8495856995192668], prediction=1.0
(7, apache hadoop) --> prob=[0.9768636139518375,0.02313638604816238], prediction=0.0