The Way of Spark Mastery (Advanced): Reading the Spark Source Code, Section 8: Task Execution
Task Execution
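In the previous section we saw that, on the Driver side, the launchTasks method of CoarseGrainedSchedulerBackend sends a launch-task command to the Executor on a Worker node. The receiver of that command is CoarseGrainedExecutorBackend (in Standalone mode), whose class definition is as follows: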
private[spark] class CoarseGrainedExecutorBackend(
    override val rpcEnv: RpcEnv,
    driverUrl: String,
    executorId: String,
    hostPort: String,
    cores: Int,
    userClassPath: Seq[URL],
    env: SparkEnv)
  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
As you can see, it extends ThreadSafeRpcEndpoint and implements that trait's receive method. The source code is as follows:
override def receive: PartialFunction[Any, Unit] = {
  case RegisteredExecutor =>
    logInfo("Successfully registered with driver")
    val (hostname, _) = Utils.parseHostPort(hostPort)
    executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
  case RegisterExecutorFailed(message) =>
    logError("Slave registration failed: " + message)
    System.exit(1)
  // Handle the LaunchTask command sent from the Driver
  case LaunchTask(data) =>
    if (executor == null) {
      logError("Received LaunchTask command but executor was null")
      System.exit(1)
    } else {
      // Deserialize the task description
      val taskDesc = ser.deserialize[TaskDescription](data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      // Have the Executor start running the task
      executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
        taskDesc.name, taskDesc.serializedTask)
    }
  case KillTask(taskId, _, interruptThread) =>
    if (executor == null) {
      logError("Received KillTask command but executor was null")
      System.exit(1)
    } else {
      executor.killTask(taskId, interruptThread)
    }
  case StopExecutor =>
    logInfo("Driver commanded a shutdown")
    executor.stop()
    stop()
    rpcEnv.shutdown()
}
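The dispatch style used by receive can be tried in isolation. The following is a minimal sketch, not Spark code: Registered, Launch, Kill, ToyBackend and DispatchDemo are hypothetical stand-ins, and where the real backend calls System.exit(1) on a launch that arrives before registration, this toy version only records a rejection.

```scala
// Illustrative stand-ins for Spark's RPC messages; these are NOT the real classes.
sealed trait BackendMessage
case object Registered extends BackendMessage
final case class Launch(taskId: Long, payload: String) extends BackendMessage
final case class Kill(taskId: Long) extends BackendMessage

// A toy endpoint that mirrors the shape of CoarseGrainedExecutorBackend.receive.
class ToyBackend {
  // False until registration succeeds, playing the role of the null `executor` field.
  private var executorReady = false
  val log = scala.collection.mutable.ArrayBuffer.empty[String]

  val receive: PartialFunction[Any, Unit] = {
    case Registered =>
      executorReady = true
      log += "registered"
    case Launch(id, _) if !executorReady =>
      // The real backend exits the process here; the sketch just logs.
      log += s"rejected launch $id: executor not ready"
    case Launch(id, payload) =>
      log += s"launching task $id with $payload"
    case Kill(id) =>
      log += s"killing task $id"
  }
}

object DispatchDemo {
  def run(): List[String] = {
    val backend = new ToyBackend
    backend.receive(Launch(1L, "a")) // arrives before registration
    backend.receive(Registered)
    backend.receive(Launch(2L, "b"))
    backend.receive(Kill(2L))
    backend.log.toList
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```

Because receive is a PartialFunction, a message that matches no case is simply not defined for it, which an RPC layer can detect with isDefinedAt before dispatching.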
As the code above shows, executor.launchTask starts the execution of the Task on the Worker node. Its source code is as follows:
// The launchTask method in the Executor class
def launchTask(
    context: ExecutorBackend,
    taskId: Long,
    attemptNumber: Int,
    taskName: String,
    serializedTask: ByteBuffer): Unit = {
  // Create the TaskRunner
  val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
    serializedTask)
  runningTasks.put(taskId, tr)
  // Submit the TaskRunner to the thread pool; its run method carries out the Task
  threadPool.execute(tr)
}
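The "register the runner, then hand it to a thread pool" pattern in launchTask can be sketched on its own. This is an illustrative toy under stated assumptions: ToyExecutor, LaunchDemo, the results map and shutdownAndWait are hypothetical names, and the cached thread pool is only one plausible choice of pool.

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

// A toy counterpart of Executor: it registers each runner, then hands it to a pool.
class ToyExecutor {
  private val runningTasks = new ConcurrentHashMap[Long, Runnable]()
  private val threadPool = Executors.newCachedThreadPool()
  val results = new ConcurrentHashMap[Long, Int]()

  def launchTask(taskId: Long, work: () => Int): Unit = {
    val runner = new Runnable {
      override def run(): Unit = {
        try {
          results.put(taskId, work())
        } finally {
          // Like TaskRunner.run: always deregister, even if the work throws.
          runningTasks.remove(taskId)
        }
      }
    }
    // Register before submitting, so the entry exists by the time run() removes it.
    runningTasks.put(taskId, runner)
    threadPool.execute(runner)
  }

  def runningCount: Int = runningTasks.size()

  // Hypothetical helper: stop accepting work and wait for submitted tasks to finish.
  def shutdownAndWait(): Boolean = {
    threadPool.shutdown()
    threadPool.awaitTermination(10, TimeUnit.SECONDS)
  }
}

object LaunchDemo {
  def run(): (Int, Int, Int) = {
    val executor = new ToyExecutor
    executor.launchTask(1L, () => 1 + 1)
    executor.launchTask(2L, () => 6 * 7)
    executor.shutdownAndWait()
    (executor.results.get(1L), executor.results.get(2L), executor.runningCount)
  }

  def main(args: Array[String]): Unit = println(run())
}
```

Keeping the runner in a map keyed by task id is what lets a later kill request find the right runner, which is how killTask is able to reach a task that is already executing.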
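TaskRunner is a Runnable that is executed on a thread from the executor's thread pool (it is not itself a Thread). It is an inner class defined in org.apache.spark.executor.Executor, and its source code is as follows: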
class TaskRunner(
    execBackend: ExecutorBackend,
    val taskId: Long,
    val attemptNumber: Int,
    taskName: String,
    serializedTask: ByteBuffer)
  extends Runnable {
  /** Whether this task has been killed. */
  @volatile private var killed = false
  /** How much the JVM process has spent in GC when the task starts to run. */
  @volatile var startGCTime: Long = _
  /**
   * The task to run. This will be set in run() by deserializing the task binary coming
   * from the driver. Once it is set, it will never be changed.
   */
  @volatile var task: Task[Any] = _
  def kill(interruptThread: Boolean): Unit = {
    logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
    killed = true
    if (task != null) {
      task.kill(interruptThread)
    }
  }
  override def run(): Unit = {
    val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
    val deserializeStartTime = System.currentTimeMillis()
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    // Send a status update to the Driver
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    startGCTime = computeTotalGcTime()
    try {
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
      updateDependencies(taskFiles, taskJars)
      task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
      task.setTaskMemoryManager(taskMemoryManager)
      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      if (killed) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException
      }
      logDebug("Task " + taskId + "'s epoch is " + task.epoch)
      env.mapOutputTracker.updateEpoch(task.epoch)
      // Run the actual task and measure its runtime.
      taskStart = System.currentTimeMillis()
      var threwException = true
      val (value, accumUpdates) = try {
        // Call the Task's run method; each Task type has its own implementation,
        // e.g. ShuffleMapTask and ResultTask
        val res = task.run(
          taskAttemptId = taskId,
          attemptNumber = attemptNumber,
          metricsSystem = env.metricsSystem)
        threwException = false
        res
      } finally {
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
        if (freedMemory > 0) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
            throw new SparkException(errMsg)
          } else {
            logError(errMsg)
          }
        }
      }
      val taskFinish = System.currentTimeMillis()
      // If the task has been killed, let's fail it.
      if (task.killed) {
        throw new TaskKilledException
      }
      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()
      for (m <- task.metrics) {
        // Deserialization happens in two parts: first, we deserialize a Task object, which
        // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
        m.setExecutorDeserializeTime(
          (taskStart - deserializeStartTime) + task.executorDeserializeTime)
        // We need to subtract Task.run()'s deserialization time to avoid double-counting
        m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
        m.setJvmGCTime(computeTotalGcTime() - startGCTime)
        m.setResultSerializationTime(afterSerialization - beforeSerialization)
        m.updateAccumulators()
      }
      val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
      val serializedDirectResult = ser.serialize(directResult)
      val resultSize = serializedDirectResult.limit
      // directSend = sending directly back to the driver
      val serializedResult: ByteBuffer = {
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
            s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
            s"dropping it.")
          ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
        } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
          val blockId = TaskResultBlockId(taskId)
          env.blockManager.putBytes(
            blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
          ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
        } else {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
          serializedDirectResult
        }
      }
      // Once execution completes, notify the Driver so it can update the task's state
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
    } catch {
      case ffe: FetchFailedException =>
        val reason = ffe.toTaskEndReason
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
      case _: TaskKilledException | _: InterruptedException if task.killed =>
        logInfo(s"Executor killed $taskName (TID $taskId)")
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
      case cDE: CommitDeniedException =>
        val reason = cDE.toTaskEndReason
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
      case t: Throwable =>
        // Attempt to exit cleanly by informing the driver of our failure.
        // If anything goes wrong (or this was a fatal exception), we will delegate to
        // the default uncaught exception handler, which will terminate the Executor.
        logError(s"Exception in $taskName (TID $taskId)", t)
        val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
          task.metrics.map { m =>
            m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
            m.setJvmGCTime(computeTotalGcTime() - startGCTime)
            m.updateAccumulators()
            m
          }
        }
        val serializedTaskEndReason = {
          try {
            ser.serialize(new ExceptionFailure(t, metrics))
          } catch {
            case _: NotSerializableException =>
              // t is not serializable so just send the stacktrace
              ser.serialize(new ExceptionFailure(t, metrics, false))
          }
        }
        // On failure, also send a status update so the task can be re-run later
        execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
        // Don't forcibly exit unless the exception was inherently fatal, to avoid
        // stopping other tasks unnecessarily.
        if (Utils.isFatalError(t)) {
          SparkUncaughtExceptionHandler.uncaughtException(t)
        }
    } finally {
      // Remove the task from the list of running tasks
      runningTasks.remove(taskId)
    }
  }
}
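The three-way decision TaskRunner.run makes about how to return a result is easy to lose inside the long listing, so here it is as a small pure function. This is a sketch only: ResultRoute, ResultRouting.route and the parameter names are hypothetical, and the threshold numbers in the usage note are made up for illustration rather than Spark defaults.

```scala
// Hypothetical labels for the three branches in TaskRunner.run.
sealed trait ResultRoute
case object Dropped extends ResultRoute         // over maxResultSize: only the size is reported
case object ViaBlockManager extends ResultRoute // too large for one RPC frame: stored as a block
case object Direct extends ResultRoute          // small enough to travel in the status update

object ResultRouting {
  // Same branch order and comparisons as the source; all sizes are in bytes.
  def route(
      resultSize: Long,
      maxResultSize: Long,
      frameSize: Long,
      reservedBytes: Long): ResultRoute =
    if (maxResultSize > 0 && resultSize > maxResultSize) Dropped
    else if (resultSize >= frameSize - reservedBytes) ViaBlockManager
    else Direct

  def main(args: Array[String]): Unit = {
    // Made-up thresholds: 1000-byte result cap, 100-byte frame, 10 bytes reserved.
    Seq(50L, 90L, 1001L).foreach { size =>
      println(s"$size bytes -> ${route(size, 1000L, 100L, 10L)}")
    }
  }
}
```

With those made-up thresholds, a 50-byte result goes back directly, a 90-byte result is stored through the block manager, and a 1001-byte result is dropped. Note that a maxResultSize of 0 disables the first check entirely, so an oversized result then falls through to the block-manager branch.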
The Task.run method is responsible for executing the Task. Its source code is as follows:
/**
 * Called by [[Executor]] to run this task.
 *
 * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
 * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
 * @return the result of the task along with updates of Accumulators.
 */
final def run(
    taskAttemptId: Long,
    attemptNumber: Int,
    metricsSystem: MetricsSystem)
  : (T, AccumulatorUpdates) = {
  // Context information for this task's execution
  context = new TaskContextImpl(
    stageId,
    partitionId,
    taskAttemptId,
    attemptNumber,
    taskMemoryManager,
    metricsSystem,
    internalAccumulators,
    runningLocally = false)
  TaskContext.setTaskContext(context)
  context.taskMetrics.setHostname(Utils.localHostName())
  context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
  taskThread = Thread.currentThread()
  if (_killed) {
    kill(interruptThread = false)
  }
  try {
    // Call runTask; each task type implements it differently,
    // e.g. ShuffleMapTask and ResultTask have different runTask logic
    (runTask(context), context.collectAccumulators())
  } finally {
    context.markTaskCompleted()
    try {
      Utils.tryLogNonFatalError {
        // Release memory used by this thread for shuffles
        SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
      }
      Utils.tryLogNonFatalError {
        // Release memory used by this thread for unrolling blocks
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
      }
    } finally {
      TaskContext.unset()
    }
  }
}
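Task.run is an instance of the template method pattern: a final method fixes the setup and teardown, and subclasses provide only runTask. The sketch below shows that shape with a thread-local context. ToyContext, ToyContextHolder, ToyTask and TemplateDemo are hypothetical names; the real TaskContextImpl and the memory-release steps are far richer than this.

```scala
// Hypothetical, simplified context; the real TaskContextImpl carries much more.
final class ToyContext(val partitionId: Int) {
  @volatile var completed = false
  def markTaskCompleted(): Unit = { completed = true }
}

// Per-thread holder, playing the role of TaskContext.setTaskContext / TaskContext.unset.
object ToyContextHolder {
  private val current = new ThreadLocal[ToyContext]
  def get: Option[ToyContext] = Option(current.get())
  def set(ctx: ToyContext): Unit = current.set(ctx)
  def unset(): Unit = current.remove()
}

abstract class ToyTask[T](val partitionId: Int) {
  // Subclasses supply only the computation, as ShuffleMapTask and ResultTask do.
  protected def runTask(context: ToyContext): T

  // `final` keeps the setup/teardown sequence identical for every subclass.
  final def run(): T = {
    val context = new ToyContext(partitionId)
    ToyContextHolder.set(context)
    try {
      runTask(context)
    } finally {
      // Runs on success and on failure, like the finally block in Task.run.
      context.markTaskCompleted()
      ToyContextHolder.unset()
    }
  }
}

object TemplateDemo {
  class Doubler(p: Int) extends ToyTask[Int](p) {
    protected def runTask(context: ToyContext): Int = {
      // The context is visible through the holder while the task is running.
      require(ToyContextHolder.get.exists(_ eq context))
      context.partitionId * 2
    }
  }

  class Failing extends ToyTask[Int](0) {
    protected def runTask(context: ToyContext): Int =
      throw new IllegalStateException("boom")
  }

  // Returns (result of Doubler, whether Failing failed, whether the holder was cleaned up).
  def run(): (Int, Boolean, Boolean) = {
    val ok = new Doubler(21).run()
    val failed = scala.util.Try(new Failing().run()).isFailure
    (ok, failed, ToyContextHolder.get.isEmpty)
  }

  def main(args: Array[String]): Unit = println(run())
}
```

Putting the cleanup in finally is what guarantees the thread-local is cleared even when runTask throws, which matters because pool threads are reused for later tasks.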
Taking ResultTask as an example, the source code of its runTask method is as follows:
// The runTask method in ResultTask
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val deserializeStartTime = System.currentTimeMillis()
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // Deserialize the RDD and the function to run
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  metrics = Some(context.taskMetrics)
  // Call rdd.iterator and apply func to it to carry out the computation
  func(context, rdd.iterator(partition, context))
}
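The last line of runTask, func(context, rdd.iterator(partition, context)), is the heart of a result task: a function is applied to the iterator of one partition. The following toy model shows just that step, without serialization or broadcast. ToyTaskContext, ToyRdd, ToyResultTask and ResultTaskDemo are hypothetical names, and a "partition" here is simply an in-memory sequence.

```scala
// Hypothetical stand-in for TaskContext.
final case class ToyTaskContext(stageId: Int, partitionId: Int)

// A toy "RDD": each partition is an in-memory sequence exposed through an iterator.
final class ToyRdd[T](partitions: Vector[Seq[T]]) {
  def numPartitions: Int = partitions.length
  def iterator(partition: Int, context: ToyTaskContext): Iterator[T] =
    partitions(partition).iterator
}

// Mirrors the final expression of ResultTask.runTask.
final class ToyResultTask[T, U](
    rdd: ToyRdd[T],
    func: (ToyTaskContext, Iterator[T]) => U,
    partition: Int) {
  def runTask(context: ToyTaskContext): U =
    func(context, rdd.iterator(partition, context))
}

object ResultTaskDemo {
  // Runs one task per partition and collects the per-partition results.
  def run(): Seq[Int] = {
    val rdd = new ToyRdd(Vector(Seq(1, 2, 3), Seq(10, 20)))
    val sumPartition = (_: ToyTaskContext, it: Iterator[Int]) => it.sum
    (0 until rdd.numPartitions).map { p =>
      new ToyResultTask(rdd, sumPartition, p)
        .runTask(ToyTaskContext(stageId = 0, partitionId = p))
    }
  }

  def main(args: Array[String]): Unit = println(run())
}
```

Each task sees only its own partition, so the two tasks above produce one partial sum each (6 and 30); in Spark, combining such per-partition results is left to the driver.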
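To summarize the Task execution process:
1 On the Driver side, launchTasks in org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend is called
2 On the Worker side, org.apache.spark.executor.CoarseGrainedExecutorBackend receives the LaunchTask message and calls org.apache.spark.executor.Executor.launchTask
3 The run method of org.apache.spark.executor.Executor.TaskRunner is executed on a thread from the executor's thread pool
4 org.apache.spark.scheduler.Task.run is called
5 org.apache.spark.scheduler.ResultTask.runTask is called
6 org.apache.spark.rdd.RDD.iterator is called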