Spark计算引擎源码分析-Shuffle Read

366 阅读1分钟

SortShuffleManager.getReader()

/**
 * Creates a [[BlockStoreShuffleReader]] for reduce partitions [startPartition, endPartition)
 * of the shuffle identified by `handle`.
 *
 * The block locations (grouped by BlockManagerId) are looked up from the MapOutputTracker
 * before the reader is built; whether batch fetching is enabled is decided by
 * `canUseBatchFetch`.
 */
override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
  val tracker = SparkEnv.get.mapOutputTracker
  val blocksByAddress =
    tracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
  val baseHandle: BaseShuffleHandle[K, _, C] = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
  val batchFetch = canUseBatchFetch(startPartition, endPartition, context)
  new BlockStoreShuffleReader(baseHandle, blocksByAddress, context, metrics,
    shouldBatchFetch = batchFetch)
}

获取map任务状态

MapOutputTrackerWorker.getMapSizesByExecutorId()

/**
 * Returns, grouped by BlockManagerId, the (block id, size, map index) entries of the shuffle
 * blocks that hold reduce partitions [startPartition, endPartition) of `shuffleId`.
 *
 * If the conversion fails with a MetadataFetchFailedException, the local `mapStatuses` cache
 * is cleared before the exception is rethrown, so the next call re-requests the statuses from
 * the tracker instead of reusing the stale ones.
 */
override def getMapSizesByExecutorId(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int)
  : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
  logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
  val mapStatusArray = getStatuses(shuffleId, conf)
  try {
    MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, mapStatusArray)
  } catch {
    case fetchFailure: MetadataFetchFailedException =>
      // The cached statuses are outdated; drop them all so they get fetched again.
      mapStatuses.clear()
      throw fetchFailure
  }
}

getStatuses

获取指定 shuffleId 的 MapStatus 数组;如果本地缓存(mapStatuses)中没有,则从远程的 MapOutputTrackerMaster 获取。

打印的日志:

// Emitted by getStatuses when the local mapStatuses cache has no entry for the shuffle.
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
// Emitted by getStatuses when the entry is still missing after taking the per-shuffle lock.
logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)  
// Emitted by getStatuses after the fetched bytes were deserialized, before caching them.
logInfo("Got the output locations")  
// Emitted on the master side by MapOutputTrackerMasterEndpoint.receiveAndReply.
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)  

MapOutputTrackerWorker.getStatuses()

/**
 * Returns the MapStatus array of `shuffleId`.
 *
 * The local `mapStatuses` cache is consulted first without locking. On a miss, the cache is
 * checked again inside `fetchingLock.withLock(shuffleId)`. Only if the entry is still missing
 * are the statuses requested from the tracker endpoint (GetMapOutputStatuses), deserialized
 * and stored in the cache.
 *
 * This is double-checked locking. Assuming `withLock` serializes callers per shuffleId
 * (TODO confirm), concurrent callers for the same shuffle trigger one remote fetch, and the
 * others reuse the cached result.
 *
 * @param shuffleId the shuffle whose map statuses are needed
 * @param conf      passed to MapOutputTracker.deserializeMapStatuses
 * @return the cached or freshly fetched statuses
 */
private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
  // Fast path: return the cached statuses without taking the lock.
  val statuses = mapStatuses.get(shuffleId).orNull
  if (statuses == null) {
    logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
    // Taken before acquiring the lock, so the logged duration includes any wait for it.
    val startTimeNs = System.nanoTime()
    fetchingLock.withLock(shuffleId) {
      // Re-check under the lock: another caller may have fetched and cached it meanwhile.
      var fetchedStatuses = mapStatuses.get(shuffleId).orNull
      if (fetchedStatuses == null) {
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // Synchronous request (askSync); the reply is the serialized statuses as bytes.
        val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
        fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
        logInfo("Got the output locations")
        mapStatuses.put(shuffleId, fetchedStatuses)
      }
      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
        s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
      // Value of the locked block; presumably returned by withLock to the caller.
      fetchedStatuses
    }
  } else {
    statuses
  }
}

远程获取:

  1. 通过 trackerEndpoint 发送 GetMapOutputStatuses(shuffleId) 消息:
/**
 * Sends `message` to the map output tracker endpoint and waits synchronously (askSync)
 * for a reply of type `T`.
 */
protected def askTracker[T: ClassTag](message: Any): T =
  trackerEndpoint.askSync[T](message)
  2. MapOutputTrackerMasterEndpoint.receiveAndReply 接收并处理该消息:
/**
 * Handles GetMapOutputStatuses requests. The request is not answered here: it is wrapped,
 * together with the RPC context, in a GetMapOutputMessage and posted to the tracker, whose
 * message loop sends the reply later.
 */
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case GetMapOutputStatuses(shuffleId: Int) =>
    val requester = context.senderAddress.hostPort
    logInfo(s"Asked to send map output locations for shuffle $shuffleId to $requester")
    tracker.post(new GetMapOutputMessage(shuffleId, context))
}

调用tracker.post

  def post(message: GetMapOutputMessage): Unit = {
    mapOutputRequests.offer(message)
  }

将 GetMapOutputMessage(shuffleId, context) 消息加入 mapOutputRequests。这里的 mapOutputRequests 是一个 LinkedBlockingQueue(基于链表的阻塞队列)。

  private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]

MessageLoop 是一个 Runnable,在线程中循环地从 mapOutputRequests 中取出请求,并回复对应 shuffle 序列化后的 MapStatus(可能有多个 MessageLoop 同时运行,见代码中关于 PoisonPill 的注释):MapOutputTrackerMaster.MessageLoop.run

/**
 * Loop that answers queued map-output requests.
 *
 * Each iteration takes one GetMapOutputMessage from `mapOutputRequests` and replies with the
 * serialized map statuses of the requested shuffle. The loop ends when it takes the
 * PoisonPill, which it puts back for the other loops, or when its thread is interrupted.
 */
private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        val data = mapOutputRequests.take()
        if (data == PoisonPill) {
          // Put PoisonPill back so that other MessageLoops can see it.
          mapOutputRequests.offer(PoisonPill)
          return
        }
        handleRequest(data)
      }
    } catch {
      // NonFatal does not match InterruptedException, so interrupts always end up here.
      case _: InterruptedException => // exit
    }
  }

  /**
   * Replies to one request.
   *
   * Any non-fatal error (e.g. an unregistered shuffle or a serialization failure) is logged
   * and also returned to the requester via `sendFailure`. Without that, the asking side
   * would only give up at its RPC timeout. A failure to deliver that failure is logged, so
   * this loop keeps serving the remaining requests.
   */
  private def handleRequest(data: GetMapOutputMessage): Unit = {
    val context = data.context
    try {
      val shuffleId = data.shuffleId
      val hostPort = context.senderAddress.hostPort
      logDebug("Handling request to send map output locations for shuffle " + shuffleId +
        " to " + hostPort)
      val shuffleStatus = shuffleStatuses.get(shuffleId).getOrElse(
        throw new IllegalStateException(s"Shuffle $shuffleId is not registered"))
      context.reply(
        shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast,
          conf))
    } catch {
      case NonFatal(e) =>
        logError(e.getMessage, e)
        try {
          context.sendFailure(e)
        } catch {
          case NonFatal(sendError) =>
            logError("Failed to send map output failure to requester", sendError)
        }
    }
  }
}

map地址转换

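对于给定的 MapStatus 数组和 partition 范围 [startPartition, endPartition),将其转换为 Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]。结果按 BlockManagerId 分组,列出该节点上属于这些 partition 的每个 block 的 (blockId, 大小, mapIndex)。大小为 0 的 block 会被跳过;如果某个 MapStatus 为 null,则抛出 MetadataFetchFailedException。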

/**
 * Converts map statuses into the reduce-side view of where the data of reduce partitions
 * [startPartition, endPartition) lives.
 *
 * @param shuffleId      the shuffle the statuses belong to
 * @param startPartition first reduce partition (inclusive)
 * @param endPartition   end of the reduce partition range (exclusive)
 * @param statuses       one MapStatus per map task, indexed by map index
 * @param mapIndex       if set, only the status at this map index is converted; an index
 *                       outside `statuses` yields an empty result
 * @return for each BlockManagerId, the (ShuffleBlockId, size, mapIndex) entries of the
 *         non-empty blocks it holds for the requested partitions
 * @throws MetadataFetchFailedException if a selected status is null (missing map output)
 */
def convertMapStatuses(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int,
    statuses: Array[MapStatus],
    mapIndex : Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
  assert (statuses != null)
  val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
  // The (status, map index) pairs to convert. A single requested index is looked up directly
  // instead of scanning all statuses; out-of-range indices select nothing, as a filter would.
  val selected: Iterator[(MapStatus, Int)] = mapIndex match {
    case Some(index) if index >= 0 && index < statuses.length =>
      Iterator((statuses(index), index))
    case Some(_) => Iterator.empty
    case None => statuses.iterator.zipWithIndex
  }
  for ((status, statusIndex) <- selected) {
    if (status == null) {
      val errorMessage = s"Missing an output location for shuffle $shuffleId"
      logError(errorMessage)
      throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
    } else {
      for (part <- startPartition until endPartition) {
        val size = status.getSizeForBlock(part)
        // Empty blocks are skipped so reducers never request them.
        if (size != 0) {
          splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
            ((ShuffleBlockId(shuffleId, status.mapId, part), size, statusIndex))
        }
      }
    }
  }

  splitsByAddress.iterator
}

拉取map端计算结果

参数:
spark.reducer.maxSizeInFlight 和 spark.reducer.maxReqsInFlight 共同限制了在途 fetch 请求的数据量和并发数(见下文的 isRemoteBlockFetchable)。

spark.reducer.maxSizeInFlight:默认值:48m。
shuffle read 缓冲区大小,即同时在拉取中(在途)的数据量上限;每个远程请求的目标大小为 maxBytesInFlight / 5。
如果可用内存比较多,可以增加参数大小,从而减少拉取次数。
spark.reducer.maxReqsInFlight : 默认值:Int.MaxValue。最大并发请求数量。
spark.reducer.maxBlocksInFlightPerAddress:默认值:Int.MaxValue。每个远程地址同时在拉取中的最大 block 数量。
spark.maxRemoteBlockSizeFetchToMem:默认值:200m。大于这个大小的远程 block 会被直接拉取到磁盘,而不是放在内存中。
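spark.shuffle.detectCorrupt(config.SHUFFLE_DETECT_CORRUPT):默认值:true。是否检测拉取到的 block 是否损坏。
spark.shuffle.detectCorrupt.useExtraMemory(config.SHUFFLE_DETECT_CORRUPT_MEMORY):默认值:false。是否使用额外内存,提前解压/解密部分数据流,以尽早发现损坏。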
  • 初始化 ShuffleBlockFetcherIterator 时会执行 initialize() 方法:
    1. 划分本地和远程 block,远程部分被组装成 remoteRequests(ArrayBuffer[FetchRequest])。每个远程请求的目标大小为 math.max(maxBytesInFlight / 5, 1L),以便能同时发出约 5 个并发请求
    2. 将这些 FetchRequest 随机打乱后加入队列 fetchRequests(Queue[FetchRequest])
    3. 发送 fetch 请求,直到在途数据量达到 maxBytesInFlight;大于 maxRemoteBlockSizeFetchToMem 的请求会直接拉取到磁盘
    4. 获取本地 block
打印日志
// Same log line as in initialize(): number of remote fetches started, and the time elapsed since startTimeNs.
logInfo(s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}")

ShuffleBlockFetcherIterator.initialize()

/**
 * Starts fetching. It does four things, in this order:
 *  1. Registers the cleanup callback.
 *  2. Queues the remote requests in random order.
 *  3. Sends the first remote fetches, up to maxBytesInFlight (fetchUpToMaxBytes).
 *  4. Reads the local blocks.
 */
private[this] def initialize(): Unit = {
  // Invoked on task success and failure alike, so cleanup always happens.
  context.addTaskCompletionListener(onCompleteCallback)

  // Separate local from remote blocks; only the remote part becomes fetch requests.
  val remoteFetchRequests = splitLocalRemoteBlocks()
  // Queue them in random order.
  fetchRequests ++= Utils.randomize(remoteFetchRequests)
  assert((reqsInFlight == 0) == (bytesInFlight == 0),
    s"expected reqsInFlight = 0 but found reqsInFlight = $reqsInFlight, " +
    s"expected bytesInFlight = 0 but found bytesInFlight = $bytesInFlight")

  // Initial batch of remote requests, bounded by maxBytesInFlight.
  fetchUpToMaxBytes()

  // Whatever left the queue is counted as started.
  // NOTE(review): if fetchUpToMaxBytes removes requests without sending them (e.g. deferring
  // them), this overcounts; it only affects the log line - confirm against fetchUpToMaxBytes.
  val startedFetches = remoteFetchRequests.size - fetchRequests.size
  logInfo(s"Started $startedFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}")

  fetchLocalBlocks()
  logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
}
  • 边拉取边聚合
    ShuffleBlockFetcherIterator.next():
    1. 在 while 循环中从结果队列 results(LinkedBlockingQueue[FetchResult])中取结果,只要取到的结果为 null 就继续等待并 fetch
    2. 继续发送 fetch 请求,直到在途数据量达到 maxBytesInFlight
    3. 返回 (blockId, inputStream)

fetchUpToMaxBytes 方法在 ShuffleBlockFetcherIterator 初始化时以及每次迭代(next)时都会被调用,每次发送请求后,在途数据量最多达到 spark.reducer.maxSizeInFlight。由于初始化时发出的一小部分请求可能就已经达到了 maxBytesInFlight 的限制,很可能还剩下很多请求没有发送。所以每次迭代 ShuffleBlockFetcherIterator 时都会附带发送剩余的请求。如果单个请求的大小超过了 maxBytesInFlight,它只会在当前没有任何在途请求(bytesInFlight == 0)时才被发送(见下面的 isRemoteBlockFetchable)。在拿到拉取结果之前,next 中的 while 循环会一直等待。大于 maxRemoteBlockSizeFetchToMem 的请求会直接拉取到磁盘。

/**
 * Whether the head request of `fetchReqQueue` may be sent now.
 *
 * The queue must be non-empty. When nothing is in flight (bytesInFlight == 0), the head
 * request is always allowed, even if it alone exceeds maxBytesInFlight. Otherwise, sending
 * it must keep two limits: in-flight requests within maxReqsInFlight, and in-flight bytes
 * within maxBytesInFlight.
 */
def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
  if (fetchReqQueue.isEmpty) {
    false
  } else if (bytesInFlight == 0) {
    true
  } else {
    val requestSlotFree = reqsInFlight + 1 <= maxReqsInFlight
    requestSlotFree && bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight
  }
}

聚合计算

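和 map 端不同的是,只要定义了聚合函数(aggregator),reduce 端就一定要做聚合;而 map 端只有在 mapSideCombine 为 true 时才聚合。聚合时采用边插入边聚合、按需溢写磁盘的方式,与 ExternalSorter 一致(ExternalAppendOnlyMap 通过 maybeSpill 判断是否需要溢写)。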

  • 如果定义了聚合函数,且定义了map端聚合,那么ExternalAppendOnlyMap使用mergeCombiners作为聚合函数
  • 如果定义了聚合函数,且没有定义map端聚合,那么ExternalAppendOnlyMap使用mergeValue作为聚合函数
  • 如果没有定义聚合函数,不需要聚合直接返回迭代器

BlockStoreShuffleReader.read()

// Reduce-side aggregation. Which combine function applies depends on whether the map side
// already combined the values.
val aggregatedIter: Iterator[Product2[K, C]] = dep.aggregator match {
  case Some(aggregator) if dep.mapSideCombine =>
    // The incoming values are already combiners, so merge combiners per key.
    val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
    aggregator.combineCombinersByKey(combinedKeyValuesIterator, context)
  case Some(aggregator) =>
    // Raw values of unknown type. The dependency is expected to have ensured they are
    // compatible with this aggregator, which converts them to the combined type C.
    val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
    aggregator.combineValuesByKey(keyValuesIterator, context)
  case None =>
    // No aggregator: pass the records through unchanged.
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}

ExternalAppendOnlyMap.insertAll()

/**
 * Inserts every record of `entries` into the in-memory map, combining values per key.
 * A key without a value gets `createCombiner(value)`; a key with one gets
 * `mergeValue(oldValue, value)`.
 *
 * Before each record is inserted, the current map's estimated size does two things:
 *  - it raises the peak-memory figure if it is a new maximum;
 *  - it is passed to `maybeSpill`.
 * If `maybeSpill` returns true, `currentMap` is replaced with a new empty
 * SizeTrackingAppendOnlyMap. Presumably the old contents were spilled; see maybeSpill.
 * The record then goes into the new map.
 *
 * @param entries the key/value records to insert
 * @throws IllegalStateException if `currentMap` is null, which per the error message is
 *                               the case after iterator has been called
 */
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
  if (currentMap == null) {
    throw new IllegalStateException(
      "Cannot insert new elements into a map after calling iterator")
  }
  // An update function for the map that we reuse across entries to avoid allocating
  // a new closure each time
  var curEntry: Product2[K, V] = null
  // Presumably hadVal tells whether the key already had a value and oldVal is that value.
  val update: (Boolean, C) => C = (hadVal, oldVal) => {
    if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
  }

  while (entries.hasNext) {
    // The update closure reads this var, so it must be set before changeValue runs.
    curEntry = entries.next()
    // Size is sampled before this record is added.
    val estimatedSize = currentMap.estimateSize()
    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
    if (maybeSpill(currentMap, estimatedSize)) {
      currentMap = new SizeTrackingAppendOnlyMap[K, C]
    }
    currentMap.changeValue(curEntry._1, update)
    addElementsRead()
  }
}

排序

之后,如果定义了 key 的排序(keyOrdering),则使用 ExternalSorter 按 key 进行排序,这个操作和 map 端一致。

// Sort the (possibly aggregated) records by key with an ExternalSorter, as on the map side.
val sorter: ExternalSorter[K, C, C] = new ExternalSorter[K, C, C](
  context,
  ordering = Some(keyOrd),
  serializer = dep.serializer)
sorter.insertAll(aggregatedIter)