Spark ShuffleWriter

795 阅读6分钟

Spark的shuffle系统很复杂,同时我认为这也是比MR快的最重要原因。MR是早期产物,Spark必然取其精华,弃其糟粕

三种writer

// 通过shuffleHandle来决定是哪一种ShuffleWrite
    handle match {
      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
        new UnsafeShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
          context.taskMemoryManager(),
          unsafeShuffleHandle,
          mapId,
          context,
          env.conf)
      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
        new BypassMergeSortShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
          bypassMergeSortHandle,
          mapId,
          context,
          env.conf)
      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
    }

UnsafeShuffleWriter

源码中说这种方式是序列化的形式,非常高效。看看走他的条件:

  // try to buffer map outputs in a serialized form, since this is more efficient
  def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
    val shufId = dependency.shuffleId
    val numPartitions = dependency.partitioner.numPartitions
    if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
    // 得支持序列化后寻址
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
        s"${dependency.serializer.getClass.getName}, does not support object relocation")
      false
    } else if (dependency.aggregator.isDefined) {
    // 不能有聚合逻辑!
      log.debug(
        s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
      false
    } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
    // 下游分区数有限制。
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
      false
    } else {
      log.debug(s"Can use serialized shuffle for shuffle $shufId")
      true
    }
  }

这有一个条件dependency.serializer.supportsRelocationOfSerializedObjects

# 这是源码中的伪代码,即说明了什么是supportsRelocationOfSerializedObjects
 serOut.open()
 position = 0
 serOut.write(obj1)
 serOut.flush()
 position = # of bytes written to stream so far
 obj1Bytes = output[0:position-1]
 serOut.write(obj2)
 serOut.flush()
 position2 = # of bytes written to stream so far
 obj2Bytes = output[position:position2-1]
 serIn.open([obj2bytes] concatenate [obj1bytes]) should return (obj2, obj1)

这个不知道面试有没有人问这个,用到了很多unsafe操作。记录一下细节

// 这是UnsafeShuffleWriter的write方法。
// 方法签名为:write(scala.collection.Iterator<Product2<K, V>> records) 
// 这个迭代器就是stage最后一个rdd身上的迭代器
  while (records.hasNext()) {
  // 有调用了另一个方法,把一个kv穿了进去。从方法名看就是把数据传入了缓冲区
    insertRecordIntoSorter(records.next());
  }

把一条数据record也就是kv写到字节数组中。这个字节数组默认1MB。

  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
    assert(sorter != null);
    // 我们数据的key
    final K key = record._1();
    // 分区号
    final int partitionId = partitioner.getPartition(key);
    // 这是个byteArrayOutputStream,把里面的字节数组置空
    serBuffer.reset();
    // serOutputStream是能序列化写的流。把key序列化写进字节数组
    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
    // 写value
    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
    serOutputStream.flush();
    // kv一共序列化完有多大
    final int serializedRecordSize = serBuffer.size();
    assert (serializedRecordSize > 0);
    // 还有一个东西是sorter,Platform.BYTE_ARRAY_OFFSET是字节数组第一个元素的偏移量,因为java对象有对象头。
    sorter.insertRecord(
      serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
  }

把字节数组的数据copy到内存页中。如果是第一次,需要申请内存页。

  /**
   * Write a record to the shuffle sorter.
   */
  public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
    throws IOException {
    growPointerArrayIfNecessary();
    // Need 4 bytes to store the record length.
    // 这加4个字节也就是1个int,来存字节数组的长度
    final int required = length + 4;
    // 申请内存页。这里要跟taskManager打交道
    acquireNewPageIfNecessary(required);
    // 已经申请到了内存页,肯定!=null。内存页可以是堆也可以是堆外
    assert(currentPage != null);
    // 申请的内存页里面其实就是一个long[],这个base就是里面的long[]
    // 如果是堆外的话 这里base是null
    final Object base = currentPage.getBaseObject();
    // 这是用一个long 64位存了两个信息,一个是当前是哪个内存页,一个是内存页里要写入的偏移量
    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
    // Platform就是unsafe,先放了一个4字节 代表record序列化后的长度
    Platform.putInt(base, pageCursor, length);
    // 因为放了4个字节,所以游标+4
    pageCursor += 4;
    // 从序列化完的那个字节数组copy到内存页中。
    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
    pageCursor += length;
    // inMemSorter里存了索引,记录record在内存页中的信息。分区号是用来排序的
    inMemSorter.insertRecord(recordAddress, partitionId);
  }

申请内存页过程

// required从上面看,就是record序列化完的长度+4
// taskMemoryManager是每个task都有一个。
MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize,required), this);

// 默认没配置的话 是用的堆内内存
  final val tungstenMemoryMode: MemoryMode = {
    if (conf.get(MEMORY_OFFHEAP_ENABLED)) {
      require(conf.get(MEMORY_OFFHEAP_SIZE) > 0,
        "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true")
      require(Platform.unaligned(),
        "No support for unaligned Unsafe. Set spark.memory.offHeap.enabled to false.")
      MemoryMode.OFF_HEAP
    } else {
      MemoryMode.ON_HEAP
    }
  }
// 这能看出来默认内存页是多大
  val pageSizeBytes: Long = {
    val minPageSize = 1L * 1024 * 1024   // 1MB
    val maxPageSize = 64L * minPageSize  // 64MB
    val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
    // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
    val safetyFactor = 16
    // 这里默认是使用堆内存
    val maxTungstenMemory: Long = tungstenMemoryMode match {
      case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.poolSize
      case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.poolSize
    }
    val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor)
    val default = math.min(maxPageSize, math.max(minPageSize, size))
    conf.getSizeAsBytes("spark.buffer.pageSize", default)
  }

注意

这种writer要求shuffle不能带聚合器,像reduceByKey,groupByKey是不能用的。sortByKey可以,repartition也可以。

ByPassMergeSortWriter

先看这个类的名字,跳过了MergeSort,我理解的是跳过归并排序。走这个Writer的条件是没有开启combine并且下游分区数<200。这个的核心逻辑就是下游有n个分区就开辟n个IO,然后写出n个临时文件,之后把这些临时文件拼接成一个大文件,并且生成一个索引文件记录其中每个分区的offset。下游reduce直接拉取。

这个做法适用于groupByKey sortByKey 且下游分区数不能超过阈值(默认200)如果下游分区数太多那么产生的临时文件就太多了。

和MR对比

MR的MapTask每次输出的时候必然会做一次排序(分区有序且分区内key有序)。Spark提出这个writer,在数据量不大的情况下,不进行排序,直接向磁盘写文件。

SortShuffleWriter

这种writer是写入内存缓冲区。一旦使用内存缓冲区buffer必然要进行排序的操作,因为一个Task里面有多个分区的数据,所以buffer在spill的时候要进行排序。这样spill出的多个小文件最后走一个归并排序,就是一个整体分区有序的文件,并且同时生成一个索引文件标识这个文件中每个分区的offset。这样下游reduce端进行拉取的时候非常快。这里其实和MR处理套路基本一致


他的缓冲区有两种形式maparray

array

如果没有开启mapSideCombine,使用array作为缓冲区。这种方式是解决ByPass不满足的时候,即下游分区数过多的时候。

  // Stick values into our buffer
  while (records.hasNext) {
    addElementsRead()
    val kv = records.next()
    // 这里是 p k v
    buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
    maybeSpillCollection(usingMap = false)
  }

spill的时候会依据partition排序

map

如果开启了mapSideCombine,使用一个数组实现的HashMap来作为缓冲区。如果是map做缓冲区,由于本身就是kv的,所以很容易做聚合。

数组实现的HashMap比jdk的省空间。对象是有成本的,并且引用也是有大小的

  // Combine values in-memory first using our AppendOnlyMap
  // 这就是aggregator的三个函数
  val mergeValue = aggregator.get.mergeValue
  val createCombiner = aggregator.get.createCombiner
  var kv: Product2[K, V] = null
  // 这是一个函数。如果是同一个key,则走聚合逻辑。这就是所谓的mapSideCombine
  val update = (hadValue: Boolean, oldValue: C) => {
    if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
  }
  while (records.hasNext) {
    addElementsRead()
    kv = records.next()
    map.changeValue((getPartition(kv._1), kv._1), update)
    maybeSpillCollection(usingMap = true)
  }

spill的时候同样对partition排序。

这里提一下面试常问的Hash冲突的解决办法,挂链表解决(jdk HashMap)开放地址法(Spark这个缓冲区用的就是这个)

和MR对比

MR中默认没有combine,得人为开启。并且combine是发生在环形缓冲区spill的时候。而Spark这个是往buffer写的时候就进行了聚合,并且Spark的一些算子默认就开启了combine。所以spill的次数必然会比MR小,从而减少了IO。