Spark Source Code Analysis (2): Stage Division and the Scheduling Mechanism


0 Overview

A Spark application involves three concepts: Job, Stage, and Task.

1 A Job is bounded by Action operations: each Action that is invoked triggers one Job;

2 A Stage is a subset of a Job, bounded by RDD wide dependencies (i.e. shuffles): each shuffle introduces a new Stage;

3 A Task is a subset of a Stage, measured by parallelism (the number of partitions): a Stage has as many Tasks as its final RDD has partitions.

Overall, Spark's task scheduling happens on two levels: Stage-level scheduling and Task-level scheduling (a small illustration of these boundaries follows).
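As a minimal local-mode sketch (everything here is illustrative; the file name "input.txt" is hypothetical, not from the original article), this is how the three boundaries show up in a word-count-style program:

import org.apache.spark.{SparkConf, SparkContext}

object BoundaryDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("boundary-demo").setMaster("local[2]"))

    val pairs = sc.textFile("input.txt", 4)   // 4 partitions -> 4 tasks in the first stage
      .flatMap(_.split(" "))
      .map((_, 1))
    val counts = pairs.reduceByKey(_ + _)     // shuffle -> a new Stage boundary

    counts.count()    // Action #1 -> Job 0 (a ShuffleMapStage plus a ResultStage)
    counts.collect()  // Action #2 -> Job 1 (its ShuffleMapStage is usually skipped, since the shuffle output already exists)
    sc.stop()
  }
}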

Through Transformation operations, Spark RDDs form a lineage (dependency) graph, i.e. a DAG; calling an Action finally triggers a Job and its scheduled execution. Two schedulers take part in this: DAGScheduler and TaskScheduler. DAGScheduler is responsible for Stage-level scheduling: it cuts a Job into Stages and packages each Stage into a TaskSet that is handed to TaskScheduler. TaskScheduler is responsible for Task-level scheduling: it dispatches the TaskSets received from DAGScheduler to Executors according to the configured scheduling policy. During scheduling, SchedulerBackend provides the available resources; SchedulerBackend has several implementations, each interfacing with a different resource management system.

While the Driver initializes the SparkContext, it also initializes the DAGScheduler, TaskScheduler, SchedulerBackend and HeartbeatReceiver, and starts the SchedulerBackend and the HeartbeatReceiver. SchedulerBackend requests resources through the ApplicationMaster and keeps pulling suitable Tasks from TaskScheduler to dispatch to Executors. HeartbeatReceiver receives the Executors' heartbeats, monitors whether they are alive, and notifies TaskScheduler. The main members of SparkContext are:

private var _conf: SparkConf = _
private var _eventLogDir: Option[URI] = None
private var _eventLogCodec: Option[String] = None
private var _env: SparkEnv = _
private var _jobProgressListener: JobProgressListener = _
private var _statusTracker: SparkStatusTracker = _
private var _progressBar: Option[ConsoleProgressBar] = None
private var _ui: Option[SparkUI] = None
private var _hadoopConfiguration: Configuration = _
private var _executorMemory: Int = _
private var _schedulerBackend: SchedulerBackend = _
private var _taskScheduler: TaskScheduler = _
private var _heartbeatReceiver: RpcEndpointRef = _
@volatile private var _dagScheduler: DAGScheduler = _
private var _applicationId: String = _
private var _applicationAttemptId: Option[String] = None
private var _eventLogger: Option[EventLoggingListener] = None
private var _executorAllocationManager: Option[ExecutorAllocationManager] = None
private var _cleaner: Option[ContextCleaner] = None
private var _listenerBusStarted: Boolean = false
private var _jars: Seq[String] = _
private var _files: Seq[String] = _
private var _shutdownHookRef: AnyRef = _

The important members:

SparkConf: the configuration object.

SparkEnv: the runtime environment, including the communication environment.

SchedulerBackend: the communication backend, mainly used for talking to the Executors (providing them with resources).

TaskScheduler: task-level scheduling.

DAGScheduler: the stage scheduler, responsible for dividing a job into stages.
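As a minimal sketch (local mode assumed; the app name and config values are illustrative, not from the article), all of these members get created while the SparkContext constructor runs:

import org.apache.spark.{SparkConf, SparkContext}

object ContextInitDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("context-init-demo")
      .setMaster("local[2]")                 // local mode, just for illustration
      .set("spark.scheduler.mode", "FIFO")   // later read by the TaskScheduler (see section 3)

    // Constructing the SparkContext creates SparkEnv, DAGScheduler, TaskScheduler,
    // SchedulerBackend and the HeartbeatReceiver described above.
    val sc = new SparkContext(conf)
    println(sc.applicationId)
    sc.stop()
  }
}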

1 RDDs and Their Dependencies

// imports and class skeleton added so the example compiles on its own
// (this mirrors Spark's bundled JavaWordCount example)
import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.SparkSession;

import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

public final class JavaWordCount {
  private static final Pattern SPACE = Pattern.compile(" ");

  public static void main(String[] args) throws Exception {

    if (args.length < 1) {
      System.err.println("Usage: JavaWordCount <file>");
      System.exit(1);
    }

    SparkSession spark = SparkSession
      .builder()
      .appName("JavaWordCount")
      .getOrCreate();

    JavaRDD<String> lines = spark.read().textFile(args[0]).javaRDD();

    JavaRDD<String> words = lines.flatMap(s -> Arrays.asList(SPACE.split(s)).iterator());

    JavaPairRDD<String, Integer> ones = words.mapToPair(s -> new Tuple2<>(s, 1));

    JavaPairRDD<String, Integer> counts = ones.reduceByKey((i1, i2) -> i1 + i2);

    List<Tuple2<String, Integer>> output = counts.collect();
    for (Tuple2<?,?> tuple : output) {
      System.out.println(tuple._1() + ": " + tuple._2());
    }
    spark.stop();
  }
}

Take the simplest WordCount as the example. Step into the textFile source first; you can see that it actually calls map:

def textFile(
    path: String,
    minPartitions: Int = defaultMinPartitions): RDD[String] = withScope {
  assertNotStopped()
  hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
    minPartitions).map(pair => pair._2.toString).setName(path)
}

Step into the map function:

def map[U: ClassTag](f: T => U): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}

You can see that a MapPartitionsRDD is constructed inside, so map returns a MapPartitionsRDD. Similarly, stepping into flatMap shows that it also returns a MapPartitionsRDD. Now let's look at the MapPartitionsRDD class itself:

private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev) {

  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }
}

MapPartitionsRDD extends the abstract class RDD, passing the parent in as prev; looking further inside (the RDD(prev) constructor wraps the parent in a OneToOneDependency), we arrive at OneToOneDependency, which expresses the dependency between the two RDDs. Step into it:

@DeveloperApi
class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
  override def getParents(partitionId: Int): List[Int] = List(partitionId)
}

This is what we usually call a narrow dependency.
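A quick way to convince yourself is a spark-shell check (a sketch, not from the article; sc is the shell's SparkContext):

// spark-shell sketch: map produces a MapPartitionsRDD with a OneToOneDependency
val base   = sc.parallelize(1 to 10, 2)
val mapped = base.map(_ * 2)
println(mapped.getClass.getSimpleName)       // MapPartitionsRDD
println(mapped.dependencies.head.getClass)   // class org.apache.spark.OneToOneDependency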

Now step into reduceByKey and keep following it; you will find:

def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope {
  combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner)
}

Internally it goes through combineByKeyWithClassTag, which wraps the functions into an Aggregator; there we find that it constructs a

new ShuffledRDD[K, V, C](self, partitioner)
    .setSerializer(serializer)
    .setAggregator(aggregator)
    .setMapSideCombine(mapSideCombine)

Step into ShuffledRDD and, in its getDependencies, we find:

override def getDependencies: Seq[Dependency[_]] = {
  val serializer = userSpecifiedSerializer.getOrElse {
    val serializerManager = SparkEnv.get.serializerManager
    if (mapSideCombine) {
      serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
    } else {
      serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
    }
  }
  List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}

Here is ShuffleDependency: this is exactly where the wide dependency (as opposed to the narrow dependency we saw above) comes from.
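Again, a quick spark-shell check makes this visible (a sketch, not from the article; sc is the shell's SparkContext):

// spark-shell sketch: reduceByKey produces a ShuffledRDD with a ShuffleDependency
val pairs  = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)), 2)
val counts = pairs.reduceByKey(_ + _)
println(counts.getClass.getSimpleName)       // ShuffledRDD
println(counts.dependencies.head.getClass)   // class org.apache.spark.ShuffleDependency
println(counts.toDebugString)                // the indentation step marks the stage boundary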

2 Stage Division

Step into collect and you can see it calls runJob:

def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}

Follow runJob through its chain of overloads until you reach:

def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    resultHandler: (Int, U) => Unit): Unit = {
  if (stopped.get()) {
    throw new IllegalStateException("SparkContext has been shutdown")
  }
  val callSite = getCallSite
  val cleanedFunc = clean(func)
  logInfo("Starting job: " + callSite.shortForm)
  if (conf.getBoolean("spark.logLineage", false)) {
    logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
  }
  dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
  progressBar.foreach(_.finishAll())
  rdd.doCheckpoint()
}

You can see it calls dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get), i.e. execution is handed to the DAG scheduler.

Step into it:

def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}

Here we see the submitJob function:

def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): JobWaiter[U] = {
  // Check to make sure we are not launching a task on a partition that does not exist.
  val maxPartitions = rdd.partitions.length
  partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
    throw new IllegalArgumentException(
      "Attempting to access a non-existent partition: " + p + ". " +
        "Total number of partitions: " + maxPartitions)
  }

  val jobId = nextJobId.getAndIncrement()
  if (partitions.size == 0) {
    // Return immediately if the job is running 0 tasks
    return new JobWaiter[U](this, jobId, 0, resultHandler)
  }

  assert(partitions.size > 0)
  val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
  val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
  eventProcessLoop.post(JobSubmitted(
    jobId, rdd, func2, partitions.toArray, callSite, waiter,
    SerializationUtils.clone(properties)))
  waiter
}

Within it,

eventProcessLoop.post(JobSubmitted(
  jobId, rdd, func2, partitions.toArray, callSite, waiter,
  SerializationUtils.clone(properties)))

the key is the post call. Step into it:

def post(event: E): Unit = {
  eventQueue.put(event)
}

It simply puts the event onto an event queue (eventQueue). The part to focus on is the thread that consumes eventQueue, shown below: it takes an event off the queue and calls onReceive.

private val eventThread = new Thread(name) {
  setDaemon(true)

  override def run(): Unit = {
    try {
      while (!stopped.get) {
        val event = eventQueue.take()
        try {
          onReceive(event)
        } catch {
          case NonFatal(e) =>
            try {
              onError(e)
            } catch {
              case NonFatal(e) => logError("Unexpected error in " + name, e)
            }
        }
      }
    } catch {
      case ie: InterruptedException => // exit even if eventQueue is not empty
      case NonFatal(e) => logError("Unexpected error in " + name, e)
    }
  }

}
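The pattern here is a plain single-threaded event loop: producers post events into a blocking queue and a daemon thread takes them off and dispatches them. A stripped-down, standalone sketch of the same idea (illustrative only, not Spark's EventLoop class):

import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.atomic.AtomicBoolean

abstract class TinyEventLoop[E](name: String) {
  private val queue   = new LinkedBlockingDeque[E]()
  private val stopped = new AtomicBoolean(false)

  private val eventThread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped.get) {
          val event = queue.take()   // blocks until someone posts an event
          onReceive(event)
        }
      } catch {
        case _: InterruptedException => // exit quietly on stop()
      }
    }
  }

  def start(): Unit = eventThread.start()
  def post(event: E): Unit = queue.put(event)
  def stop(): Unit = { stopped.set(true); eventThread.interrupt() }

  protected def onReceive(event: E): Unit
}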

Step into onReceive and you can see it delegates to doOnReceive, which is a pattern match over the event types; the event we just posted matches the JobSubmitted case:

private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
  case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
    dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

  case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
    dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)

  case StageCancelled(stageId, reason) =>
    dagScheduler.handleStageCancellation(stageId, reason)

  case JobCancelled(jobId, reason) =>
    dagScheduler.handleJobCancellation(jobId, reason)

  case JobGroupCancelled(groupId) =>
    dagScheduler.handleJobGroupCancelled(groupId)

  case AllJobsCancelled =>
    dagScheduler.doCancelAllJobs()

  case ExecutorAdded(execId, host) =>
    dagScheduler.handleExecutorAdded(execId, host)

  case ExecutorLost(execId, reason) =>
    val filesLost = reason match {
      case SlaveLost(_, true) => true
      case _ => false
    }
    dagScheduler.handleExecutorLost(execId, filesLost)

  case BeginEvent(task, taskInfo) =>
    dagScheduler.handleBeginEvent(task, taskInfo)

  case GettingResultEvent(taskInfo) =>
    dagScheduler.handleGetTaskResult(taskInfo)

  case completion: CompletionEvent =>
    dagScheduler.handleTaskCompletion(completion)

  case TaskSetFailed(taskSet, reason, exception) =>
    dagScheduler.handleTaskSetFailed(taskSet, reason, exception)

  case ResubmitFailedStages =>
    dagScheduler.resubmitFailedStages()
}

Step into handleJobSubmitted; its core method is createResultStage:

private def createResultStage(
    rdd: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    jobId: Int,
    callSite: CallSite): ResultStage = {
  val parents = getOrCreateParentStages(rdd, jobId)
  val id = nextStageId.getAndIncrement()
  val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
  stageIdToStage(id) = stage
  updateJobIdStageIdMaps(jobId, stage)
  stage
}

Step in: this is where the stage division we talk about happens. The important method here is getOrCreateParentStages; its core code is:

private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
  getShuffleDependencies(rdd).map { shuffleDep =>
    getOrCreateShuffleMapStage(shuffleDep, firstJobId)
  }.toList
}

Step into getShuffleDependencies; its job is to collect the immediate (first-level) shuffle dependencies of the given RDD:

private[scheduler] def getShuffleDependencies(
    rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
  val parents = new HashSet[ShuffleDependency[_, _, _]]
  val visited = new HashSet[RDD[_]]
  val waitingForVisit = new Stack[RDD[_]]
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      toVisit.dependencies.foreach {
        case shuffleDep: ShuffleDependency[_, _, _] =>
          parents += shuffleDep
        case dependency =>
          waitingForVisit.push(dependency.rdd)
      }
    }
  }
  parents
}

Now look at the important method getOrCreateShuffleMapStage; as its name says, it creates (or reuses) a ShuffleMapStage. And as we just saw, createResultStage ultimately produces a ResultStage, whose parents are those ShuffleMapStages.


2.0 Summary

Spark's task scheduling starts with cutting the DAG, which is mainly done by DAGScheduler. When an Action operation is encountered, a Job is triggered and handed to DAGScheduler for submission; the figure below shows the method-call flow involved in submitting a Job.

[Figure: method-call flow of Job submission]

  1. A Job is formed from the final RDD together with the Action called on it;

  2. SparkContext hands the Job to DAGScheduler for submission. DAGScheduler cuts the DAG formed by the RDD lineage into Stages: starting from the final RDD it walks the dependencies backwards and checks whether each parent dependency is a wide dependency, i.e. it splits at shuffles. RDDs connected only by narrow dependencies end up in the same Stage and can be computed in a pipelined fashion. The resulting Stages fall into two kinds: the ResultStage, which is the most downstream Stage of the DAG and is determined by the Action, and ShuffleMapStages, which prepare data for their downstream Stages.
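To see the split concretely, toDebugString prints the lineage with one indentation step per shuffle, and each indentation level corresponds to a Stage. A spark-shell sketch reusing the word-count chain from section 1 ("input.txt" is a hypothetical path; the printed ids and partition counts are approximate):

// spark-shell sketch: one shuffle -> two stages
val counts = sc.textFile("input.txt")
  .flatMap(_.split(" "))
  .map((_, 1))
  .reduceByKey(_ + _)
println(counts.toDebugString)
// output is roughly:
// (2) ShuffledRDD[4] at reduceByKey ...        <- read by the ResultStage
//  +-(2) MapPartitionsRDD[3] at map ...        <- ShuffleMapStage starts here
//     |  MapPartitionsRDD[2] at flatMap ...
//     |  input.txt MapPartitionsRDD[1] at textFile ...
//     |  input.txt HadoopRDD[0] at textFile ...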

3 Task Division

As in the code above, once the ResultStage has been created, submitStage is called:

finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)

...

submitStage(finalStage)

Step in and you will see the stage is handled according to its state. Look at getMissingParentStages: it determines whether the stage still has un-computed parent stages. If there are missing parents, they are submitted recursively first; otherwise submitMissingTasks(stage, jobId.get) submits the tasks of this stage.

private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          submitStage(parent)
        }
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}

Step into submitMissingTasks. The key part is a pattern match that creates different task types for the different stage kinds. Note partitionsToCompute, which is mapped over: it comes from stage.findMissingPartitions(), a method with two implementations (one in ShuffleMapStage and one in ResultStage). Both return a collection of partition ids, which leads to the conclusion: the number of tasks in a stage is determined by the number of partitions of the last RDD in that stage.

val tasks: Seq[Task[_]] = try {
  val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
  stage match {
    case stage: ShuffleMapStage =>
      stage.pendingPartitions.clear()
      partitionsToCompute.map { id =>
        val locs = taskIdToLocations(id)
        val part = stage.rdd.partitions(id)
        stage.pendingPartitions += id
        new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
          taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
          Option(sc.applicationId), sc.applicationAttemptId)
      }

    case stage: ResultStage =>
      partitionsToCompute.map { id =>
        val p: Int = stage.partitions(id)
        val part = stage.rdd.partitions(p)
        val locs = taskIdToLocations(id)
        new ResultTask(stage.id, stage.latestInfo.attemptId,
          taskBinary, part, locs, id, properties, serializedTaskMetrics,
          Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
      }
  }
}
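The two branches above confirm that conclusion: a stage produces exactly one task per partition of its last RDD, which is easy to check in a spark-shell sketch (not from the article; sc is the shell's SparkContext):

// spark-shell sketch: task count per stage == partition count of the stage's last RDD
val data    = sc.parallelize(1 to 100, 4)                      // stage 0: 4 ShuffleMapTasks
val reduced = data.map(i => (i % 2, i)).reduceByKey(_ + _, 2)  // stage 1: 2 ResultTasks
println(reduced.getNumPartitions)                              // 2
reduced.collect()                                              // check the task counts in the Spark UI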

We then find that a taskScheduler is used to submit the tasks:

taskScheduler.submitTasks(new TaskSet(
  tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))

Let's look at the implementation of this method (in TaskSchedulerImpl):

override def submitTasks(taskSet: TaskSet) {
  val tasks = taskSet.tasks
  logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
  this.synchronized {
    val manager = createTaskSetManager(taskSet, maxTaskFailures)
    val stage = taskSet.stageId
    val stageTaskSets =
      taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
    stageTaskSets(taskSet.stageAttemptId) = manager
    val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
      ts.taskSet != taskSet && !ts.isZombie
    }
    if (conflictingTaskSet) {
      throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
        s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
    }
    schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

    if (!isLocal && !hasReceivedTask) {
      starvationTimer.scheduleAtFixedRate(new TimerTask() {
        override def run() {
          if (!hasLaunchedTask) {
            logWarning("Initial job has not accepted any resources; " +
              "check your cluster UI to ensure that workers are registered " +
              "and have sufficient resources")
          } else {
            this.cancel()
          }
        }
      }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
    }
    hasReceivedTask = true
  }
  backend.reviveOffers()
}

Going further, there is

val manager = createTaskSetManager(taskSet, maxTaskFailures)

and then we come to

schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

Here comes the key part: the schedulableBuilder, the builder of the scheduling structure. We can see it being created in initialize:

def initialize(backend: SchedulerBackend) {
  this.backend = backend
  schedulableBuilder = {
    schedulingMode match {
      case SchedulingMode.FIFO =>
        new FIFOSchedulableBuilder(rootPool)
      case SchedulingMode.FAIR =>
        new FairSchedulableBuilder(rootPool, conf)
      case _ =>
        throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " +
        s"$schedulingMode")
    }
  }
  schedulableBuilder.buildPools()
}

Inside it there are two key builders: FIFOSchedulableBuilder and FairSchedulableBuilder.

Step into FIFOSchedulableBuilder first:

private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
  extends SchedulableBuilder with Logging {

  override def buildPools() {
    // nothing
  }

  override def addTaskSetManager(manager: Schedulable, properties: Properties) {
    rootPool.addSchedulable(manager)
  }
}

You can see a rootPool here. rootPool is the root of the scheduling pool hierarchy: TaskSetManagers are added into it, and the scheduler later walks it in FIFO or FAIR order to pick the next TaskSet to run.

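Which builder gets used is driven by configuration. A minimal sketch of selecting FAIR mode and a pool (the pool name "demoPool" and the allocation-file path are illustrative, not from the article):

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("scheduler-mode-demo")
  .setMaster("local[2]")
  .set("spark.scheduler.mode", "FAIR")   // default is FIFO
  // .set("spark.scheduler.allocation.file", "/path/to/fairscheduler.xml") // optional pool definitions
val sc = new SparkContext(conf)

// In FAIR mode, jobs submitted from this thread go into the named pool
sc.setLocalProperty("spark.scheduler.pool", "demoPool")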

Let's continue with

backend.reviveOffers()

reviveOffers is an abstract method on SchedulerBackend, with different implementations for different deployment modes.

Step into the cluster-mode implementation (CoarseGrainedSchedulerBackend):

override def reviveOffers() {
  driverEndpoint.send(ReviveOffers)
}

Keep following the ReviveOffers message; inside the driver endpoint's receive handler there is

case ReviveOffers =>
  makeOffers()

makeOffers is there to obtain taskDescs, the descriptions of the tasks to launch:

// Make fake resource offers on all executors
private def makeOffers() {
  // Make sure no executor is killed while some task is launching on it
  val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
    // Filter out executors under killing
    val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
    val workOffers = activeExecutors.map { case (id, executorData) =>
      new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
    }.toIndexedSeq
    scheduler.resourceOffers(workOffers)
  }
  if (!taskDescs.isEmpty) {
    launchTasks(taskDescs)
  }
}
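A WorkerOffer is just a small record of what an Executor can currently offer. The sketch below mirrors the essentials with an illustrative case class (not Spark's private one):

// Illustrative mirror of the resource-offer record built in makeOffers
case class DemoWorkerOffer(executorId: String, host: String, freeCores: Int)

// scheduler.resourceOffers(offers) then walks rootPool in scheduling order and,
// for each TaskSetManager, places tasks onto offers that still have free cores,
// preferring offers that are data-local for the task.
val offers = Seq(
  DemoWorkerOffer("executor-1", "host-a", 4),
  DemoWorkerOffer("executor-2", "host-b", 2)
)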

Finally, launchTasks is used to send the tasks to the Executors:

private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  for (task <- tasks.flatten) {
    val serializedTask = TaskDescription.encode(task)
    if (serializedTask.limit >= maxRpcMessageSize) {
      scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
        try {
          var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
            "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
            "spark.rpc.message.maxSize or using broadcast variables for large values."
          msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
          taskSetMgr.abort(msg)
        } catch {
          case e: Exception => logError("Exception in error callback", e)
        }
      }
    }
    else {
      val executorData = executorDataMap(task.executorId)
      executorData.freeCores -= scheduler.CPUS_PER_TASK

      logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
        s"${executorData.executorHost}.")

      executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
    }
  }
}
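The size guard at the top of launchTasks is governed by spark.rpc.message.maxSize. A hedged note on how one would typically react to hitting it (the config value and variable names are examples, not from the article):

// Serialized tasks larger than spark.rpc.message.maxSize (in MiB, default 128) abort the TaskSet.
// This usually indicates a large closure; prefer broadcasting big read-only data over raising the limit.
val conf = new SparkConf().set("spark.rpc.message.maxSize", "256")   // raise only as a last resort

// Better: ship a large lookup table once per executor via a broadcast variable
// val lookup = sc.broadcast(bigMap)   // then reference lookup.value inside tasks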

3.0 Summary

Spark Task scheduling is done by TaskScheduler. As shown above, DAGScheduler packages each Stage into a TaskSet and hands it to TaskScheduler, which wraps the TaskSet into a TaskSetManager and adds it to the scheduling queue. A TaskSetManager monitors and manages the Tasks of one Stage, and TaskScheduler schedules in units of TaskSetManagers. As mentioned earlier, after TaskScheduler is initialized the SchedulerBackend is started; it deals with the outside world, accepts Executor registrations and maintains Executor state. In other words, SchedulerBackend is the keeper of the "grain" (resources): after starting, it periodically asks TaskScheduler whether there are tasks to run, in effect saying "I have this much surplus grain, do you want it?". When asked, TaskScheduler picks a TaskSetManager from the scheduling queue according to the configured scheduling policy and schedules it. The rough method-call flow is shown in the figure below:

[Figure: Task scheduling method-call flow]

In the figure above, after the TaskSetManager is added to the rootPool scheduling pool, SchedulerBackend's reviveOffers method is called, which sends a ReviveOffers message to the driverEndpoint. On receiving ReviveOffers, the driverEndpoint calls makeOffers: it filters out the alive Executors (the Executors that reverse-registered with the Driver at startup) and wraps each of them into a WorkerOffer. With these compute resources (WorkerOffers) prepared, taskScheduler calls resourceOffers against them to assign tasks onto the Executors.