Alink漫谈(三) AllReduce通信模型 之 共享

707 阅读13分钟

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。

本文将带领大家来分析Alink中通讯模型AllReduce的实现。AllReduce在Alink中应用较多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了这个通讯模型。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

0x01 MPI是什么

MPI(Message-Passing Interface)是一个跨语言的通讯协议,用于编写并行计算,支持点对点和广播。

MPI的目标是高性能、大规模性和可移植性。MPI在今天仍为高性能计算的主要模型。

其特点是

  • A partitioned address space 每个线程只能通过调用api去读取非本地数据。所有的交互(Non-local Memory)都需要协同进行(握手)。

  • Supports only explicit parallelization 只支持显性的并行化,用户必须明确的规定消息传递的方式。

AllReduce是MPI提供的一个基本原语,我们需要先了解reduce才能更好理解AllReduce。

  • 规约函数 MPI_Reduce :规约是来自函数式编程的一个经典概念。其将通信子内各进程的同一个变量参与规约计算,并向指定的进程输出计算结果。比如通过一个函数将一批数据分成较小的一批数据。或者将一个数组的元素通过加法函数规约为一个数字。
  • 规约并广播函数 MPI_Allreduce :在计算规约的基础上,将计算结果分发到每一个进程中。比如函数在得到归约结果值之后,将结果值分发给每一个进程,这样的话,并行中的所有进程值都能知道结果值了。

MPI_Allreduce和MPI_Reduce的一个区别就是,MPI_Reduce函数将最后的结果只传给了指定的dest_process 号进程,而MPI_Allreduce函数可以将结果传递给所有的进程,因此所有的进程都能接收到结果。MPI_Allreduce函数的原型也因此不需要指定目标进程号。

0x02 Alink 实现MPI的思想

AllReduce在Alink中应用较多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了这个通讯模型。

AllReduce在算法实现中起到了承上启下的关键作用,即把原来串行跑的并行task强制打断,把计算结果进行汇总再分发,让串行继续执行。有一点类似大家熟悉的并发中的Barrier。

对比Flink原生KMeans算法,我们能看到AllReduce对应的是 groupBy(0).reduce。只有所有数据都产生之后,才能做groupBy操作。

 DataSet<Centroid> newCentroids = points
  // compute closest centroid for each point
  .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
  // count and sum point coordinates for each centroid
  .map(new CountAppender())
        // 这里如果是Alink,就对应了AllReduce
  .groupBy(0).reduce(new CentroidAccumulator())
  // compute new centroids from point counts and coordinate sums
  .map(new CentroidAverager());

从AllReduce的注解中我们可以清晰的看出Alink实现MPI的思想。

 * An implement of {@link CommunicateFunction} that do the AllReduce.
 *
 * AllReduce is a communication primitive widely used in MPI. In our implementation, all workers do reduce on a partition of the whole data and they all get the final reduce result.
 *
 * There're mainly three stages:
 *   1\. All workers send the there partial data to other workers for reduce.
 *   2\. All workers do reduce on all data it received and then send partial results to others.
 *   3\. All workers merge partial results into final result and put it into session context with pre-defined object name.
 */

翻译如下:

所有的workers都在部分数据上做reduce操作,所有的workers都可以获取到reduce最终结果

主要有三个阶段:
1\. 所有workers给其他workers发送需要reduce的部分数据
2\. 所有workers在它收到的数据上做reduce,然后把这个部分reduce的结果发送给其他workers
3\. 所有workers把部分reduce的结果合并成为最终结果,然后放入预定义的session 上下文变量中

"纸上得来终觉浅,绝知此事要躬行。"

Alink为了实现AllReduce,在背后做了大量的工作,下面我们一一剖析。

0x03 如何实现共享

共享是实现AllReduce的第一要务,因为在归并/广播过程中需要元数据和输入输出,如果有共享变量就可以极大简化实现。我们下面就看看Alink如何通过task manager实现共享。

1. Task相关概念

  • **Task**(任务) : Task 是一个阶段多个功能相同 subTask 的集合,类似于 Spark 中的 TaskSet。
  • **subTask**(子任务) :subTask 是 Flink 中任务最小执行单元,是一个 Java 类的实例,这个 Java 类中有属性和方法,完成具体的计算逻辑。
  • **链式优化** : 按理说应该是每个算子的一个并行度实例就是一个subtask。那么,带来很多问题,由于flink的taskmanager运行task的时候是每个task采用一个单独的线程,这就会带来很多线程切换开销,进而影响吞吐量。为了减轻这种情况,flink进行了优化,也即对subtask进行链式操作,链式操作结束之后得到的task,再作为一个调度执行单元,放到一个线程里执行。
  • **Operator Chains**(算子链) :Flink 将多个 subTask 合并成一个 Task(任务),这个过程叫做 Operator Chains,每个任务由一个线程执行。使用 Operator Chains(算子链) 可以将多个分开的 subTask 拼接成一个任务。类似于 Spark 中的 Pipeline。
  • **Slot**(插槽) :Flink 中计算资源进行隔离的单元,一个 Slot 中可以运行多个 subTask,但是这些 subTask 必须是来自同一个 application 的不同阶段的 subTask。结果就是,每个slot可以执行job的一整个pipeline。

Flink 中的程序本质上是并行的。在执行期间,每一个算子(Transformation)都有一个或多个算子subTask(Operator SubTask),每个算子的 subTask 之间都是彼此独立,并在不同的线程中执行,并且可能在不同的机器或容器上执行。

同一个application,多个不同 task的 subTask,可以运行在同一个 slot 资源槽中。同一个 task 中的多个的 subTask,不能运行在一个 slot 资源槽中,他们可以分散到其他的资源槽中。对应到后面就是:AllReduceSend的多个并行度实例都不能运行在同一个slot中。

2. TaskManager

Flink 中每一个 TaskManager 都是一个JVM进程,它可能会在独立的线程上执行一个或多个 subtask。TaskManager 相当于整个集群的 Slave 节点,负责具体的任务执行和对应任务在每个节点上的资源申请和管理。

TaskManager为了对资源进行隔离和增加允许的task数,引入了slot的概念,这个slot对资源的隔离仅仅是对内存进行隔离,策略是均分。一个 TaskManager 至少有一个 slot。如果一个TM有N个Slot,则每个Slot分配到的Memory大小为整个TM Memory的1/N,同一个TM内的Slots只有Memory隔离,CPU是共享的。

客户端通过将编写好的 Flink 应用编译打包,提交到 JobManager,然后 JobManager 会根据已注册在 JobManager 中 TaskManager 的资源情况,将任务分配给有资源的 TaskManager节点,然后启动并运行任务。

TaskManager 从 JobManager 接收需要部署的任务,然后使用 Slot 资源启动 Task,建立数据接入的网络连接,接收数据并开始数据处理。同时 TaskManager 之间的数据交互都是通过数据流的方式进行的。

Flink 的任务运行其实是采用多线程的方式,一个TaskManager(TM)在多线程中并发执行多个task。这和 MapReduce 多 JVM 进行的方式有很大的区别,Flink 能够极大提高 CPU 使用效率,在多个任务和 Task 之间通过 TaskSlot 方式共享系统资源,每个 TaskManager 中通过管理多个 TaskSlot 资源池进行对资源进行有效管理。

对应到后面就是:在一个TaskManager中间运行的多个并行的AllReduceSend实例都会共享这个TaskManager中所有静态变量。

3. 状态共享

Alink就是利用task manager的静态变量实现了变量共享。其中有几个主要类和概念比较复杂。我们从上到下进行讲解,能看到随着从上到下,需要的标示和状态逐渐增加。

3.1 概念剖析

从上往下调用层次如下:

算法角度:ComContext

用户代码调用 : context.getObj(bufferName); 这样对用户是最理想的,因为对于用户来说知道变量名字就可以经过上下文来存取。

但是ComContext则需要知道更多,比如还需要知道 自己对应的sessioin和taskID,具体下面会说明。

ComContext如此向下调用 : SessionSharedObjs.put(objName, sessionId, taskId, obj);

框架角度:IterativeComQueue

IterativeComQueue 是一个框架概念。以Kmeans为例,就是Kmeans算法对应了若干IterativeComQueue。

IterativeComQueue上拥有众多compute/communicate function,每个function都应该知道自己属于哪一个IterativeComQueue,如何和本Queue上其他function进行通信,不能和其他Queue上搞混了。这样就需要有一个概念来表标示这个Queue。于是就有了下面Session概念。

Session角度:SessionSharedObjs

为了区分每个IterativeComQueue,就产生了session这个概念。这样IterativeComQueue上所有compute/communicate function都会绑定同一个session id,同一个IterativeComQueue上的所有function之间可以通信。

一个 IterativeComQueue 对应一个session,所以<"变量名" + sessionId>就对应了这个 session 能访问的某个变量。

SessionSharedObjs 包含静态成员变量 :

  • int sessionId = 0; 递增的标示,用来区分session。
  • HashMap, Long> key2Handle。映射,表示一个session中 某个变量名 对应某个变量handle。

正常来说 "某个名字的变量" 对应 "某个变量handle" 即可。即一个session中某个变量名 对应某个变量handle。但是Flink中,会有多个subtask并行操作的状态,这样就需要有一个新的概念来标示subtask对应的变量,这个变量应该和taskId有所关联。于是就有了下面的state概念。

SessionSharedObjs向下调用 : IterTaskObjKeeper.put(handle, taskId, obj);

Subtask角度:IterTaskObjKeeper

这里就是用静态变量来实现共享。是task manager中所有的 tasks (threads)都可以访问的共享变量实例。

IterTaskObjKeeper 包含静态成员变量 :

  • long handle = 0L; 递增的标示,用来区分state。
  • Map states; 是一个映射。即handle代表哪一种变量state,表示这种变量中 "哪个task" 对应的state实例,是针对subtask的一种细分。

在Flink中,一个算法会被多个subtask并行操作。如果只有一个handle,那么多个subtask共同访问,就会有大家都熟知的各种多线程操作问题。所以Alink这里将handle拆分为多个state。从subtask角度看,每个state用<handle, taskId>来唯一标示。

总结一下,就是对于同样一个变量名字,每个subtask对应的共享state其实都是独立的,大家互不干扰。共享其实就是在这个subtask上跑的各个operator之间共享。

3.2 变量实例分析

从实际执行的变量中,我们可以有一个更加清楚的认识。

// 能看出来 session 0 中,centroidAllReduce这个变量 对应的handle是 7
SessionSharedObjs.key2Handle = {HashMap@10480}  size = 9
 {Tuple2@10492} "(initCentroid,0)" -> {Long@10493} 1
 {Tuple2@10494} "(statistics,0)" -> {Long@10495} 2
 {Tuple2@10496} "(362158a2-588b-429f-b848-c901a1e15e17,0)" -> {Long@10497} 8
 {Tuple2@10498} "(k,0)" -> {Long@10499} 6
 {Tuple2@10500} "(centroidAllReduce,0)" -> {Long@10501} 7 // 这里就是所说的
 {Tuple2@10502} "(trainData,0)" -> {Long@10503} 0
 {Tuple2@10504} "(vectorSize,0)" -> {Long@10505} 3
 {Tuple2@10506} "(centroid2,0)" -> {Long@10507} 5
 {Tuple2@10508} "(centroid1,0)" -> {Long@10509} 4

// 下面能看出来,handle 7 这一种变量,因为有 4 个subtask,所以细分为4个state。 
 com.alibaba.alink.common.comqueue.IterTaskObjKeeper.states = {HashMap@10520}  size = 36
 {Tuple2@10571} "(7,0)" -> {double[15]@10572} 
 {Tuple2@10573} "(7,1)" -> {double[15]@10574} 
 {Tuple2@10577} "(7,2)" -> {double[15]@10578} 
 {Tuple2@10581} "(7,3)" -> {double[15]@10582} 

 {Tuple2@10575} "(5,0)" -> {Tuple2@10576} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@29a72fbb)"
 {Tuple2@10579} "(5,1)" -> {Tuple2@10580} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@26c52354)"
 {Tuple2@10585} "(5,2)" -> {Tuple2@10586} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@7c6ed779)"
 {Tuple2@10588} "(5,3)" -> {Tuple2@10589} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@154b8a4d)"

下面让我们结合代码,一一解析涉及的类。

3.3 ComContext

ComContext 是最上层类,用来获取runtime信息和共享变量。IterativeComQueue(BaseComQueue )上所有的compute/communicate function都通过 ComContext 来访问共享变量。比如:

public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {

    // 每一个BaseComQueue都会得到唯一一个sessionId。
    private final int sessionId = SessionSharedObjs.getNewSessionId();

    int taskId = getRuntimeContext().getIndexOfThisSubtask();

    public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
        // 获取到了一个ComContext
        ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
        if (getIterationRuntimeContext().getSuperstepNumber() == maxIter || criterion) {
            // 利用ComContext继续访问共享变量
            List<Row> model = completeResult.calc(context);
        }
    }
}

// 用户类似这么调用

double[] sendBuf = context.getObj(bufferName);

可以看出来,ComContext 就是用户应该看到的最顶层上下文概念。 taskId, sessionId 是使用关键。

  • sessionId 是在 SessionSharedObjs中定义的静态类成员变量,其会自动递增。每一个BaseComQueue都会得到唯一一个sessionId,即该Queue保持了唯一session。这样BaseComQueue中生成的ComContext都有相同的sessionId。
  • taskId是从runtime中获得。

    /**

    • Encapsulates task-specific information: name, index of subtask, parallelism and attempt number. / @Internal public class TaskInfo { /*
    • Gets the number of this parallel subtask. The numbering starts from 0 and goes up to parallelism-1 (parallelism as returned by {@link #getNumberOfParallelSubtasks()}).
    • @return The index of the parallel subtask. */ public int getIndexOfThisSubtask() { return this.indexOfSubtask; // 这里获取taskId } }

ComContext 具体类定义如下

/**
 * Context used in BaseComQueue to access basic runtime information and shared objects.
 */
public class ComContext {
 private final int taskId;
 private final int numTask;
 private final int stepNo;
 private final int sessionId;

 public ComContext(int sessionId, IterationRuntimeContext runtimeContext) {
  this.sessionId = sessionId;
  this.numTask = runtimeContext.getNumberOfParallelSubtasks();
  this.taskId = runtimeContext.getIndexOfThisSubtask();
  this.stepNo = runtimeContext.getSuperstepNumber();
 }

 /**
  * Put an object into shared objects for access of other QueueItem of the same taskId.
  *
  * @param objName object name
  * @param obj     object itself.
  */
 public void putObj(String objName, Object obj) {
  SessionSharedObjs.put(objName, sessionId, taskId, obj);
 }
}

// 比如具体举例如下
this = {ComContext@10578} 
 taskId = 4
 numTask = 8
 stepNo = 1
 sessionId = 0

3.4 SessionSharedObjs

SessionSharedObjs是再下一层的类,维护shared session objects, 这个session 共享是通过 sessionId 做到的。

SessionSharedObjs 维护了一个静态类变量 sessionId,由此区分各个Session。

SessionSharedObjs核心是 HashMap<Tuple2<String, Integer>, Long> key2Handle。即 <"变量名" + sessionId> ---> <真实变量 handle> 的一个映射。

一个 IterativeComQueue 对应一个session,所以<"变量名" + sessionId>就对应了这个 IterativeComQueue 能访问的某个变量,正常来说有一个变量handle即可。

但是因为一个 IterativeComQueue会被若干subtask并行执行,所以为了互斥和区分,所以每个handle又细分为若干state,每个state用<handle, taskId>来唯一标示。在下面会提到。

/**
 * An static class that manage shared objects for {@link BaseComQueue}s.
 */
class SessionSharedObjs implements Serializable {
 private static HashMap<Tuple2<String, Integer>, Long> key2Handle = new HashMap<>();
 private static int sessionId = 0;
 private static ReadWriteLock rwlock = new ReentrantReadWriteLock();

 /**
  * Get a new session id.
  * All access operation should bind with a session id. This id is usually shared among compute/communicate function of an {@link IterativeComQueue}.
  *
  * @return new session id.
  */
 synchronized static int getNewSessionId() {
  return sessionId++;
 }    

 static void put(String objName, int session, int taskId, Object obj) {
  rwlock.writeLock().lock();
  try {
   Long handle = key2Handle.get(Tuple2.of(objName, session));
   if (handle == null) {
    handle = IterTaskObjKeeper.getNewHandle();
    key2Handle.put(Tuple2.of(objName, session), handle);
   }
      // 这里进行调用。taskId也是辨识关键。
   IterTaskObjKeeper.put(handle, taskId, obj);
  } finally {
   rwlock.writeLock().unlock();
  }
 }    
}

3.5 IterTaskObjKeeper

这是最底层的共享类,是在task manager进程的堆内存上的一个静态实例。task manager的所有task (threads) 都可以分享。

看源码可知,IterTaskObjKeeper 是通过一个静态变量states实现了在整个JVM内共享。而具体内容是由 'handle' and 'taskId' 来共同决定。

IterTaskObjKeeper维持了 handle 递增来作为 “变量state” 的唯一种类标识。

用<handle, taskId>来作为“变量state”的唯一标识。这个就是在 task manager process 堆内存中被大家共享的变量。

即handle代表哪一种变量state,<handle, taskId>表示这种变量中,对应哪一个task的哪一个变量。 这是针对task的一种细分。

/**
 * A 'state' is an object in the heap memory of task manager process,
 * shared across all tasks (threads) in the task manager.

 * Note that the 'state' is shared by all tasks on the same task manager,
 * users should guarantee that no two tasks modify a 'state' at the same time.

 * A 'state' is identified by 'handle' and 'taskId'.
 */
public class IterTaskObjKeeper implements Serializable {
 private static Map <Tuple2 <Long, Integer>, Object> states;

 /**
  * A 'handle' is a unique identifier of a state.
  */
 private static long handle = 0L;

 private static ReadWriteLock rwlock = new ReentrantReadWriteLock();

 static {
  states = new HashMap <>();
 }

 /**
  * @note Should get a new handle on the client side and pass it to transformers.
  */
 synchronized public static long getNewHandle() {
  return handle++;
 }

 public static void put(long handle, int taskId, Object state) {
  rwlock.writeLock().lock();
  try {
   states.put(Tuple2.of(handle, taskId), state); 
  } finally {
   rwlock.writeLock().unlock();
  }
 }
}

0xFF 参考

我的并行计算之路(四)MPI集合通信之Reduce和Allreduce

Message Passing Interface(MPI)

Flink 之 Dataflow、Task、subTask、Operator Chains、Slot 介绍

Flink运行时之TaskManager执行Task

★★★★★★关于生活和技术的思考★★★★★★

微信公众账号:罗西的思考

如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。