
A Super-Simple Walkthrough of the structured stream Source Code


To help you understand how structured stream runs, I will use a code example to walk through its basic execution flow and the principles behind it.

Here is a simple piece of code:

 1 val spark = SparkSession
 2   .builder
 3   .appName("StructuredNetworkWordCount")
 4   .master("local[4]")
 5 
 6   .getOrCreate()
 7 spark.conf.set("spark.sql.shuffle.partitions", 4)
 8 
 9 import spark.implicits._
10 val words = spark.readStream
11   .format("socket")
12   .option("host", "localhost")
13   .option("port", 9999)
14   .load()
15 
16 val df1 = words.as[String]
17   .flatMap(_.split(" "))
18   .toDF("word")
19   .groupBy("word")
20   .count()
21 
22 df1.writeStream
23   .outputMode("complete")
24   .format("console")
25   .trigger(ProcessingTime("10 seconds"))
26   .start()
27 
28 spark.streams.awaitAnyTermination()

This code is a word count. It reads data from a socket source, splits each line of text on " " into a Dataset of words, converts that into a DataFrame with a named column ("word"), then groups by the word column and counts the occurrences of each word. Finally it writes the result to the console, with a 10-second batch interval.
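If you start the program, open a netcat session on port 9999 (the host and port the socket source reads from), and type a line such as "apache spark apache spark streaming", the console sink prints one block per batch that looks roughly like this (purely illustrative; the exact rows and counts depend on what you type):

-------------------------------------------
Batch: 0
-------------------------------------------
+---------+-----+
|     word|count|
+---------+-----+
|   apache|    2|
|    spark|    2|
|streaming|    1|
+---------+-----+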

  

Now let's analyze how it works. Spark is built around lazy evaluation: in the example above, nothing before line 22 actually computes anything on the data; the computation (the expressions and functions) is merely recorded inside the DataFrame, and real computation only begins once writeStream ... start() is called starting at line 22. Why?

Because it gives the Spark engine a chance to optimize the whole computation before running it.

For example, suppose a database stores people's names and ages, and I want to print the names of the first ten people older than 20 to the console. My Spark code would look like this:

1 df.filter{ row =>
2   row._2 > 20 }
3   .show(10)

If every line of code were computed as soon as it ran, then at line 2 I would filter all the data in df, keeping everyone older than 20, and at line 3 I would take the first 10 rows of that result and print them.

See the problem? The output only needs 10 people older than 20, yet I filtered every single person. Really I only need to keep filtering until I have found 10 matches; everything after that never needs to be examined. This is exactly the kind of optimization that Spark's lazy evaluation makes possible.
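To make this concrete, here is a small self-contained sketch of lazy evaluation (the data, the LazyDemo name, and the column names are made up for illustration): the filter merely records a predicate in the plan, explain only prints the plans, and nothing touches the data until show is called.

import org.apache.spark.sql.SparkSession

object LazyDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("LazyDemo").master("local[2]").getOrCreate()
    import spark.implicits._

    // A made-up (name, age) table standing in for the database.
    val people = Seq(("alice", 25), ("bob", 17), ("carol", 42)).toDF("name", "age")

    // Nothing runs here: filter only adds a predicate to the logical plan.
    val adults = people.filter($"age" > 20)

    // explain(true) prints the parsed/analyzed/optimized/physical plans without executing them.
    // The physical plan of a limited query ends in a CollectLimit node, which is what lets
    // Spark stop scanning once enough matching rows have been produced.
    adults.limit(10).explain(true)

    // Only this action actually triggers computation.
    adults.show(10)

    spark.stop()
  }
}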

In Spark, no real computation happens until a genuine output function (an action) is called; right before the action runs, the plan is optimized and only then executed. Let's look at the source code.

Here is the code that structured stream runs every time a batch trigger fires (the runBatch method of StreamExecution):

 1  private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
 2     // Request unprocessed data from all sources.
 3     newData = reportTimeTaken("getBatch") {
 4       availableOffsets.flatMap {
 5         case (source, available)
 6           if committedOffsets.get(source).map(_ != available).getOrElse(true) =>
 7           val current = committedOffsets.get(source)
 8           val batch = source.getBatch(current, available)
 9           logDebug(s"Retrieving data from $source: $current -> $available")
10           Some(source -> batch)
11         case _ => None
12       }
13     }
14 
15     // A list of attributes that will need to be updated.
16     var replacements = new ArrayBuffer[(Attribute, Attribute)]
17     // Replace sources in the logical plan with data that has arrived since the last batch.
18     val withNewSources = logicalPlan transform {
19       case StreamingExecutionRelation(source, output) =>
20         newData.get(source).map { data =>
21           val newPlan = data.logicalPlan
22           assert(output.size == newPlan.output.size,
23             s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
24             s"${Utils.truncatedString(newPlan.output, ",")}")
25           replacements ++= output.zip(newPlan.output)
26           newPlan
27         }.getOrElse {
28           LocalRelation(output)
29         }
30     }
31 
32     // Rewire the plan to use the new attributes that were returned by the source.
33     val replacementMap = AttributeMap(replacements)
34     val triggerLogicalPlan = withNewSources transformAllExpressions {
35       case a: Attribute if replacementMap.contains(a) => replacementMap(a)
36       case ct: CurrentTimestamp =>
37         CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
38           ct.dataType)
39       case cd: CurrentDate =>
40         CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
41           cd.dataType, cd.timeZoneId)
42     }
43 
44     reportTimeTaken("queryPlanning") {
45       lastExecution = new IncrementalExecution(
46         sparkSessionToRunBatch,
47         triggerLogicalPlan,
48         outputMode,
49         checkpointFile("state"),
50         currentBatchId,
51         offsetSeqMetadata)
52       lastExecution.executedPlan // Force the lazy generation of execution plan
53     }
54 
55     val nextBatch =
56       new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))
57 
58     reportTimeTaken("addBatch") {
59       sink.addBatch(currentBatchId, nextBatch)
60     }
61 
62     awaitBatchLock.lock()
63     try {
64       // Wake up any threads that are waiting for the stream to progress.
65       awaitBatchLockCondition.signalAll()
66     } finally {
67       awaitBatchLock.unlock()
68     }
69   }

It is actually quite simple: everything before line 58 is parsing the user code, building the logical plan, optimizing it, and constructing the objects for this batch. The triggerLogicalPlan passed in at line 47 is the user logic after final optimization; it is wrapped in an IncrementalExecution, and that class, together with sparkSessionToRunBatch (the execution environment) and a RowEncoder (the serializer), forms a new Dataset. This Dataset is the code that will ultimately be shipped to the worker nodes and executed. Line 59 then hands it to the sink, queuing it up to be sent out. Let's look at one more piece of code: since we use the console as the downstream (the sink), here is the console sink's addBatch:

 1 override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
 2     val batchIdStr = if (batchId <= lastBatchId) {
 3       s"Rerun batch: $batchId"
 4     } else {
 5       lastBatchId = batchId
 6       s"Batch: $batchId"
 7     }
 8 
 9     // scalastyle:off println
10     println("-------------------------------------------")
11     println(batchIdStr)
12     println("-------------------------------------------")
13     // scalastyle:off println
14     data.sparkSession.createDataFrame(
15       data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
16       .show(numRowsToShow, isTruncated)
17   }
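As an aside, every sink plugs into the engine through this same addBatch hook. Purely as a sketch (it assumes the Spark 2.x Sink trait; LoggingSink is a made-up name, not something in Spark), a custom sink could look like this:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.streaming.Sink

// Hypothetical LoggingSink: like ConsoleSink, it receives the already-planned
// Dataset for the current batch, and calling an action on it (collect here)
// is what actually runs the job for that batch.
class LoggingSink extends Sink {
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    val rows = data.collect()
    println(s"Batch $batchId produced ${rows.length} rows")
  }
}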

The key code is the .show call on line 16 of addBatch. show is a real action; everything before this point has just been wrapping operators together. Let's look at the code behind show (it ends up in showString):

1 private[sql] def showString(_numRows: Int, truncate: Int = 20): String = {
2     val numRows = _numRows.max(0)
3     val takeResult = toDF().take(numRows + 1)
4     val hasMoreData = takeResult.length > numRows
5     val data = takeResult.take(numRows)

Line 3 goes into take:

def take(n: Int): Array[T] = head(n)

def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
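head simply wraps the query in limit(n), so from the user's point of view take(n) behaves roughly like limit(n) followed by collect() (df below is any hypothetical DataFrame):

// two roughly equivalent ways of fetching the first 5 rows
val viaTake  = df.take(5)             // head(5): withAction("head", limit(5).queryExecution)(collectFromPlan)
val viaLimit = df.limit(5).collect()  // builds the same kind of limited plan, then collects it

head then delegates to withAction, shown next: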
 1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
 2     try {
 3       qe.executedPlan.foreach { plan =>
 4         plan.resetMetrics()
 5       }
 6       val start = System.nanoTime()
 7       val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
 8         action(qe.executedPlan)
 9       }
10       val end = System.nanoTime()
11       sparkSession.listenerManager.onSuccess(name, qe, end - start)
12       result
13     } catch {
14       case e: Exception =>
15         sparkSession.listenerManager.onFailure(name, qe, e)
16         throw e
17     }
18   }

The name of this function already tells us that the real computation is about to start, and line 7 is clearly where things are prepared to be sent out for execution:

 1 def withNewExecutionId[T](
 2       sparkSession: SparkSession,
 3       queryExecution: QueryExecution)(body: => T): T = {
 4     val sc = sparkSession.sparkContext
 5     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
 6     if (oldExecutionId == null) {
 7       val executionId = SQLExecution.nextExecutionId
 8       sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
 9       executionIdToQueryExecution.put(executionId, queryExecution)
10       val r = try {
11         // sparkContext.getCallSite() would first try to pick up any call site that was previously
12         // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
13         // streaming queries would give us call site like "run at <unknown>:0"
14         val callSite = sparkSession.sparkContext.getCallSite()
15 
16         sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
17           executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
18           SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
19         try {
20           body
21         } finally {
22           sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
23             executionId, System.currentTimeMillis()))
24         }
25       } finally {
26         executionIdToQueryExecution.remove(executionId)
27         sc.setLocalProperty(EXECUTION_ID_KEY, null)
28       }
29       r
30     } else {
31       // Don't support nested `withNewExecutionId`. This is an example of the nested
32       // `withNewExecutionId`:
33       //
34       // class DataFrame {
35       //   def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
36       // }
37       //
38       // Note: `collect` will call withNewExecutionId
39       // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
40       // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
41       // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
42       // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
43       //
44       // A real case is the `DataFrame.count` method.
45       throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
46     }
47   }

Look at line 16: this is where the execution information is broadcast, including the optimized user plan, the execution id, the call site and the timestamp. That event goes out on the listener bus (it is what the web UI listens for), and the job itself is kicked off when body runs at line 20; at that point the workers start doing the work described by the logicalPlan. That is the whole flow, very basic and very simple, but hopefully quite helpful for anyone getting started with Spark.
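One last practical pointer: if you want to watch this flow from your own code instead of the web UI, Structured Streaming exposes per-query events through StreamingQueryListener. Here is a minimal sketch (it reuses the spark session from the first example and assumes a recent 2.x API):

import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// Print a line for every query start, per-batch progress update, and termination.
spark.streams.addListener(new StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"query ${event.id} started")
  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(s"batch ${event.progress.batchId}: ${event.progress.numInputRows} input rows")
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"query ${event.id} terminated")
})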

  
