0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。
本文将带领大家来分析Alink中通讯模型AllReduce的实现。AllReduce在Alink中应用较多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了这个通讯模型。
因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。
0x01 MPI是什么
MPI(Message-Passing Interface)是一个跨语言的通讯协议,用于编写并行计算,支持点对点和广播。
MPI的目标是高性能、大规模性和可移植性。MPI在今天仍为高性能计算的主要模型。
其特点是
-
A partitioned address space 每个线程只能通过调用api去读取非本地数据。所有的交互(Non-local Memory)都需要协同进行(握手)。
-
Supports only explicit parallelization 只支持显性的并行化,用户必须明确的规定消息传递的方式。
AllReduce是MPI提供的一个基本原语,我们需要先了解reduce才能更好理解AllReduce。
-
规约函数 MPI_Reduce :规约是来自函数式编程的一个经典概念。其将通信子内各进程的同一个变量参与规约计算,并向指定的进程输出计算结果。比如通过一个函数将一批数据分成较小的一批数据。或者将一个数组的元素通过加法函数规约为一个数字。
-
规约并广播函数 MPI_Allreduce :在计算规约的基础上,将计算结果分发到每一个进程中。比如函数在得到归约结果值之后,将结果值分发给每一个进程,这样的话,并行中的所有进程值都能知道结果值了。
MPI_Allreduce和MPI_Reduce的一个区别就是,MPI_Reduce函数将最后的结果只传给了指定的dest_process 号进程,而MPI_Allreduce函数可以将结果传递给所有的进程,因此所有的进程都能接收到结果。MPI_Allreduce函数的原型也因此不需要指定目标进程号。
0x02 Alink 实现MPI的思想
AllReduce在Alink中应用较多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了这个通讯模型。
AllReduce在算法实现中起到了承上启下的关键作用,即把原来串行跑的并行task强制打断,把计算结果进行汇总再分发,让串行继续执行。有一点类似大家熟悉的并发中的Barrier。
对比Flink原生KMeans算法,我们能看到AllReduce对应的是 groupBy(0).reduce。只有所有数据都产生之后,才能做groupBy操作。
DataSet<Centroid> newCentroids = points
// compute closest centroid for each point
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
// count and sum point coordinates for each centroid
.map(new CountAppender())
// 这里如果是Alink,就对应了AllReduce
.groupBy(0).reduce(new CentroidAccumulator())
// compute new centroids from point counts and coordinate sums
.map(new CentroidAverager());
从AllReduce的注解中我们可以清晰的看出Alink实现MPI的思想。
* An implement of {@link CommunicateFunction} that do the AllReduce.
*
* AllReduce is a communication primitive widely used in MPI. In our implementation, all workers do reduce on a partition of the whole data and they all get the final reduce result.
*
* There're mainly three stages:
* 1\. All workers send the there partial data to other workers for reduce.
* 2\. All workers do reduce on all data it received and then send partial results to others.
* 3\. All workers merge partial results into final result and put it into session context with pre-defined object name.
*/
翻译如下:
所有的workers都在部分数据上做reduce操作,所有的workers都可以获取到reduce最终结果
主要有三个阶段:
1\. 所有workers给其他workers发送需要reduce的部分数据
2\. 所有workers在它收到的数据上做reduce,然后把这个部分reduce的结果发送给其他workers
3\. 所有workers把部分reduce的结果合并成为最终结果,然后放入预定义的session 上下文变量中
"纸上得来终觉浅,绝知此事要躬行。"
Alink为了实现AllReduce,在背后做了大量的工作,下面我们一一剖析。
0x03 如何实现共享
共享是实现AllReduce的第一要务,因为在归并/广播过程中需要元数据和输入输出,如果有共享变量就可以极大简化实现。我们下面就看看Alink如何通过task manager实现共享。
1. Task相关概念
-
**Task**(任务) : Task 是一个阶段多个功能相同 subTask 的集合,类似于 Spark 中的 TaskSet。
-
**subTask**(子任务) :subTask 是 Flink 中任务最小执行单元,是一个 Java 类的实例,这个 Java 类中有属性和方法,完成具体的计算逻辑。
-
**链式优化** : 按理说应该是每个算子的一个并行度实例就是一个subtask。那么,带来很多问题,由于flink的taskmanager运行task的时候是每个task采用一个单独的线程,这就会带来很多线程切换开销,进而影响吞吐量。为了减轻这种情况,flink进行了优化,也即对subtask进行链式操作,链式操作结束之后得到的task,再作为一个调度执行单元,放到一个线程里执行。
-
**Operator Chains**(算子链) :Flink 将多个 subTask 合并成一个 Task(任务),这个过程叫做 Operator Chains,每个任务由一个线程执行。使用 Operator Chains(算子链) 可以将多个分开的 subTask 拼接成一个任务。类似于 Spark 中的 Pipeline。
-
**Slot**(插槽) :Flink 中计算资源进行隔离的单元,一个 Slot 中可以运行多个 subTask,但是这些 subTask 必须是来自同一个 application 的不同阶段的 subTask。结果就是,每个slot可以执行job的一整个pipeline。
Flink 中的程序本质上是并行的。在执行期间,每一个算子(Transformation)都有一个或多个算子subTask(Operator SubTask),每个算子的 subTask 之间都是彼此独立,并在不同的线程中执行,并且可能在不同的机器或容器上执行。
同一个application,多个不同 task的 subTask,可以运行在同一个 slot 资源槽中。同一个 task 中的多个的 subTask,不能运行在一个 slot 资源槽中,他们可以分散到其他的资源槽中。对应到后面就是:AllReduceSend的多个并行度实例都不能运行在同一个slot中。
2. TaskManager
Flink 中每一个 TaskManager 都是一个JVM进程,它可能会在独立的线程上执行一个或多个 subtask。TaskManager 相当于整个集群的 Slave 节点,负责具体的任务执行和对应任务在每个节点上的资源申请和管理。
TaskManager为了对资源进行隔离和增加允许的task数,引入了slot的概念,这个slot对资源的隔离仅仅是对内存进行隔离,策略是均分。一个 TaskManager 至少有一个 slot。如果一个TM有N个Slot,则每个Slot分配到的Memory大小为整个TM Memory的1/N,同一个TM内的Slots只有Memory隔离,CPU是共享的。
客户端通过将编写好的 Flink 应用编译打包,提交到 JobManager,然后 JobManager 会根据已注册在 JobManager 中 TaskManager 的资源情况,将任务分配给有资源的 TaskManager节点,然后启动并运行任务。
TaskManager 从 JobManager 接收需要部署的任务,然后使用 Slot 资源启动 Task,建立数据接入的网络连接,接收数据并开始数据处理。同时 TaskManager 之间的数据交互都是通过数据流的方式进行的。
Flink 的任务运行其实是采用多线程的方式,一个TaskManager(TM)在多线程中并发执行多个task。这和 MapReduce 多 JVM 进行的方式有很大的区别,Flink 能够极大提高 CPU 使用效率,在多个任务和 Task 之间通过 TaskSlot 方式共享系统资源,每个 TaskManager 中通过管理多个 TaskSlot 资源池进行对资源进行有效管理。
对应到后面就是:在一个TaskManager中间运行的多个并行的AllReduceSend实例都会共享这个TaskManager中所有静态变量。
3. 状态共享
Alink就是利用task manager的静态变量实现了变量共享。其中有几个主要类和概念比较复杂。我们从上到下进行讲解,能看到随着从上到下,需要的标示和状态逐渐增加。
3.1 概念剖析
从上往下调用层次如下:
算法角度:ComContext
用户代码调用 : context.getObj(bufferName); 这样对用户是最理想的,因为对于用户来说知道变量名字就可以经过上下文来存取。
但是ComContext则需要知道更多,比如还需要知道 自己对应的sessioin和taskID,具体下面会说明。
ComContext如此向下调用 : SessionSharedObjs.put(objName, sessionId, taskId, obj);
框架角度:IterativeComQueue
IterativeComQueue 是一个框架概念。以Kmeans为例,就是Kmeans算法对应了若干IterativeComQueue。
IterativeComQueue上拥有众多compute/communicate function,每个function都应该知道自己属于哪一个IterativeComQueue,如何和本Queue上其他function进行通信,不能和其他Queue上搞混了。这样就需要有一个概念来表标示这个Queue。于是就有了下面Session概念。
Session角度:SessionSharedObjs
为了区分每个IterativeComQueue,就产生了session这个概念。这样IterativeComQueue上所有compute/communicate function都会绑定同一个session id,同一个IterativeComQueue上的所有function之间可以通信。
一个 IterativeComQueue 对应一个session,所以<"变量名" + sessionId>就对应了这个 session 能访问的某个变量。
SessionSharedObjs 包含静态成员变量 :
-
int sessionId = 0; 递增的标示,用来区分session。
-
HashMap, Long> key2Handle。映射,表示一个session中 某个变量名 对应某个变量handle。
正常来说 "某个名字的变量" 对应 "某个变量handle" 即可。即一个session中某个变量名 对应某个变量handle。但是Flink中,会有多个subtask并行操作的状态,这样就需要有一个新的概念来标示subtask对应的变量,这个变量应该和taskId有所关联。于是就有了下面的state概念。
SessionSharedObjs向下调用 : IterTaskObjKeeper.put(handle, taskId, obj);
Subtask角度:IterTaskObjKeeper
这里就是用静态变量来实现共享。是task manager中所有的 tasks (threads)都可以访问的共享变量实例。
IterTaskObjKeeper 包含静态成员变量 :
-
long handle = 0L; 递增的标示,用来区分state。
-
Map states; 是一个映射。即handle代表哪一种变量state,表示这种变量中 "哪个task" 对应的state实例,是针对subtask的一种细分。
在Flink中,一个算法会被多个subtask并行操作。如果只有一个handle,那么多个subtask共同访问,就会有大家都熟知的各种多线程操作问题。所以Alink这里将handle拆分为多个state。从subtask角度看,每个state用<handle, taskId>来唯一标示。
总结一下,就是对于同样一个变量名字,每个subtask对应的共享state其实都是独立的,大家互不干扰。共享其实就是在这个subtask上跑的各个operator之间共享。
3.2 变量实例分析
从实际执行的变量中,我们可以有一个更加清楚的认识。
// 能看出来 session 0 中,centroidAllReduce这个变量 对应的handle是 7
SessionSharedObjs.key2Handle = {HashMap@10480} size = 9
{Tuple2@10492} "(initCentroid,0)" -> {Long@10493} 1
{Tuple2@10494} "(statistics,0)" -> {Long@10495} 2
{Tuple2@10496} "(362158a2-588b-429f-b848-c901a1e15e17,0)" -> {Long@10497} 8
{Tuple2@10498} "(k,0)" -> {Long@10499} 6
{Tuple2@10500} "(centroidAllReduce,0)" -> {Long@10501} 7 // 这里就是所说的
{Tuple2@10502} "(trainData,0)" -> {Long@10503} 0
{Tuple2@10504} "(vectorSize,0)" -> {Long@10505} 3
{Tuple2@10506} "(centroid2,0)" -> {Long@10507} 5
{Tuple2@10508} "(centroid1,0)" -> {Long@10509} 4
// 下面能看出来,handle 7 这一种变量,因为有 4 个subtask,所以细分为4个state。
com.alibaba.alink.common.comqueue.IterTaskObjKeeper.states = {HashMap@10520} size = 36
{Tuple2@10571} "(7,0)" -> {double[15]@10572}
{Tuple2@10573} "(7,1)" -> {double[15]@10574}
{Tuple2@10577} "(7,2)" -> {double[15]@10578}
{Tuple2@10581} "(7,3)" -> {double[15]@10582}
{Tuple2@10575} "(5,0)" -> {Tuple2@10576} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@29a72fbb)"
{Tuple2@10579} "(5,1)" -> {Tuple2@10580} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@26c52354)"
{Tuple2@10585} "(5,2)" -> {Tuple2@10586} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@7c6ed779)"
{Tuple2@10588} "(5,3)" -> {Tuple2@10589} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@154b8a4d)"
下面让我们结合代码,一一解析涉及的类。
3.3 ComContext
ComContext 是最上层类,用来获取runtime信息和共享变量。IterativeComQueue(BaseComQueue )上所有的compute/communicate function都通过 ComContext 来访问共享变量。比如:
public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {
// 每一个BaseComQueue都会得到唯一一个sessionId。
private final int sessionId = SessionSharedObjs.getNewSessionId();
int taskId = getRuntimeContext().getIndexOfThisSubtask();
public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
// 获取到了一个ComContext
ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
if (getIterationRuntimeContext().getSuperstepNumber() == maxIter || criterion) {
// 利用ComContext继续访问共享变量
List<Row> model = completeResult.calc(context);
}
}
}
// 用户类似这么调用
double[] sendBuf = context.getObj(bufferName);
可以看出来,ComContext 就是用户应该看到的最顶层上下文概念。 taskId, sessionId 是使用关键。
-
sessionId 是在 SessionSharedObjs中定义的静态类成员变量,其会自动递增。每一个BaseComQueue都会得到唯一一个sessionId,即该Queue保持了唯一session。这样BaseComQueue中生成的ComContext都有相同的sessionId。
-
taskId是从runtime中获得。
/**
- Encapsulates task-specific information: name, index of subtask, parallelism and attempt number. / @Internal public class TaskInfo { /*
- Gets the number of this parallel subtask. The numbering starts from 0 and goes up to parallelism-1 (parallelism as returned by {@link #getNumberOfParallelSubtasks()}).
- @return The index of the parallel subtask. */ public int getIndexOfThisSubtask() { return this.indexOfSubtask; // 这里获取taskId } }
ComContext 具体类定义如下
/**
* Context used in BaseComQueue to access basic runtime information and shared objects.
*/
public class ComContext {
private final int taskId;
private final int numTask;
private final int stepNo;
private final int sessionId;
public ComContext(int sessionId, IterationRuntimeContext runtimeContext) {
this.sessionId = sessionId;
this.numTask = runtimeContext.getNumberOfParallelSubtasks();
this.taskId = runtimeContext.getIndexOfThisSubtask();
this.stepNo = runtimeContext.getSuperstepNumber();
}
/**
* Put an object into shared objects for access of other QueueItem of the same taskId.
*
* @param objName object name
* @param obj object itself.
*/
public void putObj(String objName, Object obj) {
SessionSharedObjs.put(objName, sessionId, taskId, obj);
}
}
// 比如具体举例如下
this = {ComContext@10578}
taskId = 4
numTask = 8
stepNo = 1
sessionId = 0
3.4 SessionSharedObjs
SessionSharedObjs是再下一层的类,维护shared session objects, 这个session 共享是通过 sessionId 做到的。
SessionSharedObjs 维护了一个静态类变量 sessionId,由此区分各个Session。
SessionSharedObjs核心是 HashMap<Tuple2<String, Integer>, Long> key2Handle。即 <"变量名" + sessionId> ---> <真实变量 handle> 的一个映射。
一个 IterativeComQueue 对应一个session,所以<"变量名" + sessionId>就对应了这个 IterativeComQueue 能访问的某个变量,正常来说有一个变量handle即可。
但是因为一个 IterativeComQueue会被若干subtask并行执行,所以为了互斥和区分,所以每个handle又细分为若干state,每个state用<handle, taskId>来唯一标示。在下面会提到。
/**
* An static class that manage shared objects for {@link BaseComQueue}s.
*/
class SessionSharedObjs implements Serializable {
private static HashMap<Tuple2<String, Integer>, Long> key2Handle = new HashMap<>();
private static int sessionId = 0;
private static ReadWriteLock rwlock = new ReentrantReadWriteLock();
/**
* Get a new session id.
* All access operation should bind with a session id. This id is usually shared among compute/communicate function of an {@link IterativeComQueue}.
*
* @return new session id.
*/
synchronized static int getNewSessionId() {
return sessionId++;
}
static void put(String objName, int session, int taskId, Object obj) {
rwlock.writeLock().lock();
try {
Long handle = key2Handle.get(Tuple2.of(objName, session));
if (handle == null) {
handle = IterTaskObjKeeper.getNewHandle();
key2Handle.put(Tuple2.of(objName, session), handle);
}
// 这里进行调用。taskId也是辨识关键。
IterTaskObjKeeper.put(handle, taskId, obj);
} finally {
rwlock.writeLock().unlock();
}
}
}
3.5 IterTaskObjKeeper
这是最底层的共享类,是在task manager进程的堆内存上的一个静态实例。task manager的所有task (threads) 都可以分享。
看源码可知,IterTaskObjKeeper 是通过一个静态变量states实现了在整个JVM内共享。而具体内容是由 'handle' and 'taskId' 来共同决定。
IterTaskObjKeeper维持了 handle 递增来作为 “变量state” 的唯一种类标识。
用<handle, taskId>来作为“变量state”的唯一标识。这个就是在 task manager process 堆内存中被大家共享的变量。
即handle代表哪一种变量state,<handle, taskId>表示这种变量中,对应哪一个task的哪一个变量。 这是针对task的一种细分。
/**
* A 'state' is an object in the heap memory of task manager process,
* shared across all tasks (threads) in the task manager.
* Note that the 'state' is shared by all tasks on the same task manager,
* users should guarantee that no two tasks modify a 'state' at the same time.
* A 'state' is identified by 'handle' and 'taskId'.
*/
public class IterTaskObjKeeper implements Serializable {
private static Map <Tuple2 <Long, Integer>, Object> states;
/**
* A 'handle' is a unique identifier of a state.
*/
private static long handle = 0L;
private static ReadWriteLock rwlock = new ReentrantReadWriteLock();
static {
states = new HashMap <>();
}
/**
* @note Should get a new handle on the client side and pass it to transformers.
*/
synchronized public static long getNewHandle() {
return handle++;
}
public static void put(long handle, int taskId, Object state) {
rwlock.writeLock().lock();
try {
states.put(Tuple2.of(handle, taskId), state);
} finally {
rwlock.writeLock().unlock();
}
}
}
0xFF 参考
我的并行计算之路(四)MPI集合通信之Reduce和Allreduce
Message Passing Interface(MPI)
Flink 之 Dataflow、Task、subTask、Operator Chains、Slot 介绍
★★★★★★关于生活和技术的思考★★★★★★
微信公众账号:罗西的思考
如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。