Spark 源码分析(七): DAGScheduler 源码分析1(stage 划分算法)

3,451 阅读6分钟

前面几篇文章已经说清楚了从 spark 任务提交到 driver 启动,然后执行 main 方法,初始化 SparkContext 对象。

在初始化 SparkContext 对象的过程中创建了两个重要组件:

一个是 TaskScheduler(实际上是他的实现类 TaskSchedulerImpl 对象),这个对象内部会持有一个 SchedulerBackend 对象,SchedulerBackend 内部会又会持有一个 DriverEndpoint 对象(实际上就是一个 RpcEndpoint)。这样 TaskScheduler 就可以通过 SchedulerBackend 和集群资源管理器或者 Executor 对应 worker 节点进行通信做一些事情。比如向 master 节点去注册 application,master 在注册 application 的过程中会分配 worker 去启动 Executor,当 Executor 启动后又会和 TaskScheduler 进行注册。

另一个是 DAGScheduler,关于这个对象的创建过程前面没有详细讲,主要是因为 DAGScheduler 是在 SparkContext 初始化结束后,执行到 RDD 的 Action 操作的时候才会开始工作,下面就从 RDD 的 action 操作说起,看看 DAGScheduler 是怎么工作的。

还是以 wordcount 程序为例:

		val conf = new SparkConf()
      .setAppName("WordCount")
      .setMaster("local")
    val sc = new SparkContext(conf)
    val lines = sc.textFile("./file/localfile")
    val words = lines.flatMap(line => line.split(" "))
    val wordPairs = words.map(word => (word, 1))
    val wordCounts = wordPairs.reduceByKey(_ + _)
    wordCounts.foreach(wordCount => println(wordCount._1 + "  " + wordCount._2))

当代码执行到 wordCounts.foreach 时候会调用到 RDD 的 foreach 方法,RDD 的 foreach 方法会去调用 SparkContext 的 runjob 方法。

SparkContext 中会有多个 runjob 方法,最后都会走到一个 runjob 那里去,这个 runjob 方法最终会调用 DAGScheduler 的 runJob 的方法,具体可以先看下这个 SparkContext 的 runjob 方法。

def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      resultHandler: (Int, U) => Unit): Unit = {
    if (stopped.get()) {
      throw new IllegalStateException("SparkContext has been shutdown")
    }
    val callSite = getCallSite
    val cleanedFunc = clean(func)
    logInfo("Starting job: " + callSite.shortForm)
    if (conf.getBoolean("spark.logLineage", false)) {
      logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
    }
  	// 去调用 DAGScheduler 的 runjob 方法
    dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
    progressBar.foreach(_.finishAll())
    rdd.doCheckpoint()
  }

最主要的还是 DAGScheduler 中的 runjob 方法。

这个 runjob 方法内部实际上调用了 submitJob 方法,用于提交 job。该方法返回一个 JobWaiter,用于等待 DAGScheduler 任务的完成。

def runJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      resultHandler: (Int, U) => Unit,
      properties: Properties): Unit = {
    val start = System.nanoTime
    // 调用 submitJob 方法
    val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)

submitJob 方法是调用 eventProcessLoop 的 post 方法将 JobSubmitted 事件添加到 DAGScheduler 的事件队列中去。

eventProcessLoop.post(JobSubmitted(
      jobId, rdd, func2, partitions.toArray, callSite, waiter,
      SerializationUtils.clone(properties)))

这里的 eventProcessLoop 是 DAGSchedulerEventProcessLoop 对象,在 DAGScheduler 的初始化代码中可以看到。DAGSchedulerEventProcessLoop 实际上内部有一个线程,用来处理事件队列。

事件队列的处理最后会走到 DAGSchedulerEventProcessLoop 的 onReceive 的回调方法里面去。

/**
   * The main event loop of the DAG scheduler.
   */
  override def onReceive(event: DAGSchedulerEvent): Unit = {
    val timerContext = timer.time()
    try {
      // 调用 doOnReceive 方法
      doOnReceive(event)
    } finally {
      timerContext.stop()
    }
  }

后面会去调用 doOnReceive 方法,根据 event 进行模式匹配,匹配到 JobSubmitted 的 event 后实际上是去调用 DAGScheduler 的 handleJobSubmitted 这个方法。

private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
  	// 模式匹配
    case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      // 调用 handleJobSubmitted 方法
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

下面来看 handleJobSubmitted 这个方法做了哪些操作:

1,使用触发 job 的最后一个 rdd,创建 finalStage;

注: Stage 是一个抽象类,一共有两个实现,一个是 ResultStage,是用 action 中的函数计算结果的 stage;另一个是 ShuffleMapStage,是为 shuffle 准备数据的 stage。

2,构造一个 Job 对象,将上面创建的 finalStage 封装进去,这个 Job 的最后一个 stage 也就是这个 finalStage;

3,将 Job 的相关信息保存到内存的数据结构中;

4,调用 submitStage 方法提交 finalStage。

private[scheduler] def handleJobSubmitted(jobId: Int,
      finalRDD: RDD[_],
      func: (TaskContext, Iterator[_]) => _,
      partitions: Array[Int],
      callSite: CallSite,
      listener: JobListener,
      properties: Properties) {
    var finalStage: ResultStage = null
    try {
      // 使用触发 job 的最后一个 RDD 创建一个 ResultStage
      finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
    } catch {
      case e: Exception =>
        logWarning("Creating new stage failed due to exception - job: " + jobId, e)
        listener.jobFailed(e)
        return
    }

  	// 使用前面创建好的 ResultStage 去创建一个 job
  	// 这个 job 的最后一个 stage 就是 finalStage
    val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
    clearCacheLocs()
    logInfo("Got job %s (%s) with %d output partitions".format(
      job.jobId, callSite.shortForm, partitions.length))
    logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
    logInfo("Parents of final stage: " + finalStage.parents)
    logInfo("Missing parents: " + getMissingParentStages(finalStage))

    // 将 job 的相关信息存储到内存中
    val jobSubmissionTime = clock.getTimeMillis()
    jobIdToActiveJob(jobId) = job
    activeJobs += job
    finalStage.setActiveJob(job)
    val stageIds = jobIdToStageIds(jobId).toArray
    val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
    listenerBus.post(
      SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
    // 提交 finalStage
    submitStage(finalStage)
  }

下面就会走进 submitStage 方法,这个方法是用来提交 stage 的,具体做了这些操作:

1,首先会验证 stage 对应的 job id 进行校验,存在才会继续执行;

2,在提交这个 stage 之前会判断当前 stage 的状态。

如果是 running、waiting、failed 的话就不做任何操作。

如果不是这三个状态则会根据当前 stage 去往前推前面的 stage,如果能找到前面的 stage 则继续递归调用 submitStage 方法,直到当前 stage 找不到前面的 stage 为止,这时候的 stage 就相当于当前 job 的第一个 stage,然后回去调用 submitMissingTasks 方法去分配 task。

private def submitStage(stage: Stage) {
    val jobId = activeJobForStage(stage)
    // 看看当前的 job 是否存在
    if (jobId.isDefined) {
      logDebug("submitStage(" + stage + ")")
       // 判断当前 stage 的状态
      if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
        // 根据当前的 stage 去推倒前面的 stage
        val missing = getMissingParentStages(stage).sortBy(_.id)
        logDebug("missing: " + missing)
        // 如果前面已经没有 stage 了,那么久将当前 stage 去执行 submitMissingTasks 方法
        // 如果前面还有 stage 的话那么递归调用 submitStage
        if (missing.isEmpty) {
          logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
          submitMissingTasks(stage, jobId.get)
        } else {
          for (parent <- missing) {
            submitStage(parent)
          }
          // 将当前 stage 加入等待队列
          waitingStages += stage
        }
      }
    } else {
      // abortStage 终止提交当前 stage
      abortStage(stage, "No active job for stage " + stage.id, None)
    }
  }

上面最重要的一个地方就是使用当前 stage 向前推,找到前面的 stage,也是 stage 的划分算法。下面就看看 getMissingParentStages 这个划分算法做了哪些操作:

1,创建 missing 和 visited 两个 HashSet,分别用来存储根据当前 stage 向前找到的所有 stage 数据和已经调用过 visit 方法的 RDD;

2,创建一个存放 RDD 的栈,然后将传进来的 stage 中的 rdd 也就是 finalStage 中的那个 job 触发的最后一个 RDD 放入栈中;

3,然后将栈中的 RDD 拿出来调用 visit 方法,这个 visit 方法内部会根据当前 RDD 的依赖链逐个遍历所有 RDD,并且会根据相邻两个 RDD 的依赖关系来决定下面的操作:

如果是宽依赖,即 ShuffleDependency ,那么会调用 getOrCreateShuffleMapStage 创建一个新的 stage,默认每个 job 的最后一个 stage 是 ResultStage,剩余的 job 中的其它 stage 均为 ShuffleMapStage。然后会将创建的这个 stage 加入前面创建的 missing 的 HashSet 中;

如果是窄依赖,即 NarrowDependency,那么会将该 RDD 加入到前面创建的 RDD 栈中,继续遍历调用 visit 方法。

直到所有的 RDD 都遍历结束后返回前面创建的 missing 的集合。

private def getMissingParentStages(stage: Stage): List[Stage] = {
  	// 存放下面找到的所有 stage
    val missing = new HashSet[Stage]
    // 存放已经遍历过的 rdd
    val visited = new HashSet[RDD[_]]
    // We are manually maintaining a stack here to prevent StackOverflowError
    // caused by recursively visiting
    // 创建一个维护 RDD 的栈
    val waitingForVisit = new Stack[RDD[_]]
    // visit 方法
    def visit(rdd: RDD[_]) {
      // 判断当前 rdd 是否 visit 过
      if (!visited(rdd)) {
        visited += rdd
        val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
        if (rddHasUncachedPartitions) {
          // 遍历当前 RDD 的依赖链
          for (dep <- rdd.dependencies) {
            dep match {
              // 如果是宽依赖
              case shufDep: ShuffleDependency[_, _, _] =>
                // 创建 ShuffleMapStage 
                val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
                if (!mapStage.isAvailable) {
                  // 加入 missing 集合
                  missing += mapStage
                }
              // 如果是窄依赖
              case narrowDep: NarrowDependency[_] =>
                // 加入等待 visit 的集合中,准备下一次遍历
                waitingForVisit.push(narrowDep.rdd)
            }
          }
        }
      }
    }
  	// 将传入的 stage 中的 rdd 拿出来压入 waitingForVisit 的栈中
    waitingForVisit.push(stage.rdd)
    // 遍历栈里的所有 RDD 
    while (waitingForVisit.nonEmpty) {
      // 调用 visit 方法
      visit(waitingForVisit.pop())
    }
  	// 返回 missing 这个 stage 集合
    missing.toList
  }

至此,所有的 stage 都已经划分结束了。可以看出每个 Spark Application 执行代码的时候,每当碰到一个 Action 操作就会划分出一个 Job,然后每个 Job 里会根据宽窄依赖去划分出多个 stage。