1. 程式人生 > >Spark修煉之道(高階篇)——Spark原始碼閱讀:第八節 Task執行

Spark修煉之道(高階篇)——Spark原始碼閱讀:第八節 Task執行

Task執行

在上一節中,我們提到在Driver端CoarseGrainedSchedulerBackend中的launchTasks方法向Worker節點中的Executor傳送啟動任務命令,該命令的接收者是CoarseGrainedExecutorBackend(Standalone模式),類定義原始碼如下:

private[spark] class CoarseGrainedExecutorBackend(
    override val rpcEnv: RpcEnv,
    driverUrl: String,
    executorId: String,
    hostPort: String
, cores: Int, userClassPath: Seq[URL], env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {

可以看到它繼承了ThreadSafeRpcEndpoint,並對ThreadSafeRpcEndpoint中的receive方法進行了實現,具體原始碼如下:

/**
 * Message handler of this endpoint. Each case below reacts to one message type:
 *
 *  - `RegisteredExecutor`: creates the [[Executor]] for this backend.
 *  - `RegisterExecutorFailed`: logs the failure message and exits the JVM with status 1.
 *  - `LaunchTask`: deserializes the [[TaskDescription]] and asks the executor to launch it.
 *  - `KillTask`: asks the executor to kill the given task.
 *  - `StopExecutor`: stops the executor, this endpoint and the RPC environment.
 *
 * `LaunchTask` and `KillTask` exit the JVM with status 1 when no executor has been created yet.
 */
override def receive: PartialFunction[Any, Unit] = {
  case RegisteredExecutor =>
    logInfo("Successfully registered with driver")
    val (hostname, _) = Utils.parseHostPort(hostPort)
    executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)

  case RegisterExecutorFailed(message) =>
    logError("Slave registration failed: " + message)
    System.exit(1)

  // Handle the LaunchTask command sent by the driver.
  case LaunchTask(data) =>
    if (executor == null) {
      logError("Received LaunchTask command but executor was null")
      System.exit(1)
    } else {
      // Deserialize the task description carried by the message.
      val taskDesc = ser.deserialize[TaskDescription](data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      // Hand the task over to the executor, which starts running it.
      executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
        taskDesc.name, taskDesc.serializedTask)
    }

  case KillTask(taskId, _, interruptThread) =>
    if (executor == null) {
      logError("Received KillTask command but executor was null")
      System.exit(1)
    } else {
      executor.killTask(taskId, interruptThread)
    }

  case StopExecutor =>
    logInfo("Driver commanded a shutdown")
    // NOTE(review): unlike LaunchTask/KillTask this branch does not check `executor` for null;
    // confirm StopExecutor cannot arrive before RegisteredExecutor.
    executor.stop()
    stop()
    rpcEnv.shutdown()
}

從前面的程式碼可以看到,通過 executor.launchTask方法啟動Worker節點上Task的執行,其原始碼如下:

// launchTask method of the Executor class
/**
 * Wraps the serialized task in a [[TaskRunner]], records it in `runningTasks` under its task id
 * and submits it to `threadPool` for execution.
 *
 * @param context        backend passed to the runner
 * @param taskId         id of the task; also the key used in `runningTasks`
 * @param attemptNumber  attempt number passed to the runner
 * @param taskName       task name passed to the runner
 * @param serializedTask serialized task passed to the runner
 */
def launchTask(
      context: ExecutorBackend,
      taskId: Long,
      attemptNumber: Int,
      taskName: String,
      serializedTask: ByteBuffer): Unit = {
    // Build the runnable that will execute this task attempt.
    val runner = new TaskRunner(context, taskId, attemptNumber, taskName, serializedTask)
    // Register the runner before handing it to the pool.
    runningTasks.put(taskId, runner)
    // The pool invokes runner.run(), which performs the actual task execution.
    threadPool.execute(runner)
  }

TaskRunner實現了Runnable介面(它本身不是執行緒,而是交由執行緒池中的執行緒來執行),它是一個內部類,被定義在org.apache.spark.executor.Executor類當中,具體原始碼如下:

 /**
  * Runnable that executes one task attempt and reports its state through `execBackend`.
  *
  * `run()` deserializes the task, runs it, serializes the result and sends a status update
  * (RUNNING first, then FINISHED, FAILED or KILLED). The runner always removes itself from
  * `runningTasks` when `run()` exits.
  *
  * @param execBackend    backend used to send task status updates
  * @param taskId         id of the task attempt
  * @param attemptNumber  attempt number forwarded to `Task.run`
  * @param taskName       task name, used in log messages
  * @param serializedTask serialized task together with its file/jar dependencies
  */
 class TaskRunner(
      execBackend: ExecutorBackend,
      val taskId: Long,
      val attemptNumber: Int,
      taskName: String,
      serializedTask: ByteBuffer)
    extends Runnable {

    /** Whether this task has been killed. */
    @volatile private var killed = false

    /** How much the JVM process has spent in GC when the task starts to run. */
    @volatile var startGCTime: Long = _

    /**
     * The task to run. This will be set in run() by deserializing the task binary coming
     * from the driver. Once it is set, it will never be changed.
     */
    @volatile var task: Task[Any] = _

    /**
     * Marks this runner as killed and, if the task has already been deserialized, forwards
     * the kill request to it.
     *
     * @param interruptThread passed through to `Task.kill`
     */
    def kill(interruptThread: Boolean): Unit = {
      logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
      killed = true
      if (task != null) {
        task.kill(interruptThread)
      }
    }

    /**
     * Deserializes and runs the task, then reports the outcome to the driver via
     * `execBackend.statusUpdate`.
     */
    override def run(): Unit = {
      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
      val deserializeStartTime = System.currentTimeMillis()
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      // Tell the driver that the task is now running.
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      startGCTime = computeTotalGcTime()

      try {
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        updateDependencies(taskFiles, taskJars)
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
        task.setTaskMemoryManager(taskMemoryManager)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        var threwException = true
        val (value, accumUpdates) = try {
          // Invoke Task.run; concrete tasks (e.g. ShuffleMapTask, ResultTask) supply their own
          // runTask implementation.
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          res
        } finally {
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
          if (freedMemory > 0) {
            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
            // Only escalate the leak to an exception when the task itself succeeded, so that an
            // exception already in flight from task.run is not replaced.
            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
              throw new SparkException(errMsg)
            } else {
              logError(errMsg)
            }
          }
        }
        val taskFinish = System.currentTimeMillis()

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }

        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()

        for (m <- task.metrics) {
          // Deserialization happens in two parts: first, we deserialize a Task object, which
          // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
          m.setExecutorDeserializeTime(
            (taskStart - deserializeStartTime) + task.executorDeserializeTime)
          // We need to subtract Task.run()'s deserialization time to avoid double-counting
          m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
          m.setJvmGCTime(computeTotalGcTime() - startGCTime)
          m.setResultSerializationTime(afterSerialization - beforeSerialization)
          m.updateAccumulators()
        }

        val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            // Too large overall: send only an IndirectTaskResult; the block is never stored.
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
            // Too large for a single message: store it in the BlockManager and send a reference.
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }
        // The task finished: notify the driver so it can update the task state.
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

      } catch {
        case ffe: FetchFailedException =>
          val reason = ffe.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

        // `task` is still null if the exception was raised before deserialization completed.
        // Check for null first: an NPE thrown from this guard would escape the catch block and
        // no status update would be sent. Such exceptions fall through to the Throwable case.
        case _: TaskKilledException | _: InterruptedException if task != null && task.killed =>
          logInfo(s"Executor killed $taskName (TID $taskId)")
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))

        case cDE: CommitDeniedException =>
          val reason = cDE.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

        case t: Throwable =>
          // Attempt to exit cleanly by informing the driver of our failure.
          // If anything goes wrong (or this was a fatal exception), we will delegate to
          // the default uncaught exception handler, which will terminate the Executor.
          logError(s"Exception in $taskName (TID $taskId)", t)

          // `task` may be null here, hence the Option wrapper.
          val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
            task.metrics.map { m =>
              m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
              m.setJvmGCTime(computeTotalGcTime() - startGCTime)
              m.updateAccumulators()
              m
            }
          }
          val serializedTaskEndReason = {
            try {
              ser.serialize(new ExceptionFailure(t, metrics))
            } catch {
              case _: NotSerializableException =>
                // t is not serializable so just send the stacktrace
                ser.serialize(new ExceptionFailure(t, metrics, false))
            }
          }
          // Report the failure to the driver as well, so the task can be re-run later.
          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)

          // Don't forcibly exit unless the exception was inherently fatal, to avoid
          // stopping other tasks unnecessarily.
          if (Utils.isFatalError(t)) {
            SparkUncaughtExceptionHandler.uncaughtException(t)
          }

      } finally {
        // Remove this runner from the map of running tasks.
        runningTasks.remove(taskId)
      }
    }
  }

Task run方法負責Task的執行,其原始碼如下:

 /**
   * Called by [[Executor]] to run this task.
   *
   * Creates the [[TaskContextImpl]] for this attempt, installs it via
   * `TaskContext.setTaskContext`, runs `runTask` and finally marks the task completed,
   * releases shuffle and unroll memory and unsets the task context.
   *
   * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
   * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
   * @param metricsSystem metrics system handed to the task context
   * @return the result of the task along with updates of Accumulators.
   */
  final def run(
    taskAttemptId: Long,
    attemptNumber: Int,
    metricsSystem: MetricsSystem)
  : (T, AccumulatorUpdates) = {
    // Build the context that describes the environment this task attempt runs in.
    context = new TaskContextImpl(
      stageId,
      partitionId,
      taskAttemptId,
      attemptNumber,
      taskMemoryManager,
      metricsSystem,
      internalAccumulators,
      runningLocally = false)
    TaskContext.setTaskContext(context)
    context.taskMetrics.setHostname(Utils.localHostName())
    context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
    // Record the executing thread.
    // NOTE(review): presumably so kill() can interrupt it; kill() is not shown here.
    taskThread = Thread.currentThread()
    // A kill was requested before run() started: issue it again now that `context` exists.
    if (_killed) {
      kill(interruptThread = false)
    }
    try {
      // Delegate to runTask; each task type (e.g. ShuffleMapTask, ResultTask) implements its
      // own runTask logic.
      (runTask(context), context.collectAccumulators())
    } finally {
      // Cleanup runs whether runTask returned or threw.
      context.markTaskCompleted()
      try {
        Utils.tryLogNonFatalError {
          // Release memory used by this thread for shuffles
          SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
        }
        Utils.tryLogNonFatalError {
          // Release memory used by this thread for unrolling blocks
          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
        }
      } finally {
        // Always clear the task context, even if the memory release above failed.
        TaskContext.unset()
      }
    }
  }

以ResultTask為例,其runTask方法原始碼如下:

// runTask method of ResultTask
  /**
   * Deserializes the RDD and the result function from `taskBinary`, records how long that
   * took in `_executorDeserializeTime`, publishes the context's task metrics and then applies
   * the function to the iterator of this task's partition.
   *
   * @param context the task context passed to the function and to `rdd.iterator`
   * @return the value produced by the deserialized function
   */
  override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val startedAt = System.currentTimeMillis()
    val closureSer = SparkEnv.get.closureSerializer.newInstance()
    val binaryBuffer = ByteBuffer.wrap(taskBinary.value)
    val loader = Thread.currentThread.getContextClassLoader
    // Recover the RDD and the function to run on it.
    val (rdd, func) =
      closureSer.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](binaryBuffer, loader)
    _executorDeserializeTime = System.currentTimeMillis() - startedAt

    metrics = Some(context.taskMetrics)
    // Compute the partition through rdd.iterator and feed it to the function.
    val partitionData = rdd.iterator(partition, context)
    func(context, partitionData)
  }

總結一下Task的執行過程:
1 呼叫Driver端org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend中的launchTasks
2 Worker端的org.apache.spark.executor.CoarseGrainedExecutorBackend在receive方法中收到LaunchTask訊息後,呼叫org.apache.spark.executor.Executor.launchTask
3 執行緒池執行org.apache.spark.executor.Executor.TaskRunner(Runnable)中的run方法
4 呼叫org.apache.spark.scheduler.Task.run方法
5 呼叫org.apache.spark.scheduler.ResultTask.runTask方法
6 呼叫org.apache.spark.rdd.RDD.iterator方法