0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文是漫谈系列的第二篇,将从源码入手,带领大家具体剖析Alink设计思想和架构为何。
因为Alink的公开资料太少,所以均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。
书接上文,我们讲解迭代计算框架。
0x01 底层--迭代计算框架
这里对应如下设计原则:
-
构建一套战术打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,还可以让用户基于此可以快速开发算法。
-
采用最简单,最常见的开发语言和开发模式。
让我们想想看,大概有哪些基础工作需要做:
-
如何初始化
-
如何通信
-
如何分割代码,如何广播代码
-
如何分割数据,如何广播数据
-
如何迭代算法
其中最重要的概念是IterativeComQueue,这是把通信或者计算抽象成ComQueueItem,然后把ComQueueItem串联起来形成队列。这样就形成了面向迭代计算场景的一套迭代通信计算框架。
再次把目录结构列在这里:
./java/com/alibaba/alink/common:
MLEnvironment.java linalg MLEnvironmentFactory.java mapper
VectorTypes.java model comqueue utils io
里面大致有 :
-
Flink 封装模块 :MLEnvironment.java, MLEnvironmentFactory.java。
-
线性代数模块:linalg。
-
计算/通讯队列模块:comqueue,其中ComputeFunction进行计算,比如训练算法。
-
映射模块:mapper,其中Mapper进行各种映射,比如 ModelMapper 把模型映射为数值(就是转换算法)。
-
模型 :model,主要是用来读取model source。
-
基础模块:utils,io。
算法组件在其linkFrom函数中,会做如下操作:
-
先进行部分初始化,此时会调用部分Flink算子,比如groupBy等等。
-
再将算法逻辑剥离出来,委托给Mapper或者ComQueueItem。
-
Mapper或者ComQueueItem会调用Flink map算子或者mapPartition算子等。
-
调用Flink算子过程就是把算法分割然后适配到Flink上的过程。
下面就一一阐述。
1. Flink上下文封装
MLEnvironment 是个重要的类。其封装了Flink开发所必须要的运行上下文。用户可以通过这个类来获取各种实际运行环境,可以建立table,可以运行SQL语句。
/**
* The MLEnvironment stores the necessary context in Flink.
* Each MLEnvironment will be associated with a unique ID.
* The operations associated with the same MLEnvironment ID
* will share the same Flink job context.
*/
public class MLEnvironment {
private ExecutionEnvironment env;
private StreamExecutionEnvironment streamEnv;
private BatchTableEnvironment batchTableEnv;
private StreamTableEnvironment streamTableEnv;
}
2. Function
Function是计算框架中,对于计算和通讯等业务逻辑的最小模块。具体定义如下。
-
ComputeFunction 是计算模块。
-
CommunicateFunction 是通讯模块。CommunicateFunction和ComputeFunction都是ComQueueItem子类,它们是业务逻辑实现者。
-
CompareCriterionFunction 是判断模块,用来判断何时结束循环。这就允许用户指定迭代终止条件。
-
CompleteResultFunction 用来在结束循环时候调用,作为循环结果。
-
Mapper也是一种Funciton,即Mapper Function。
后续将统称为 Function。
/**
* Basic build block in {@link BaseComQueue}, for either communication or computation.
*/
public interface ComQueueItem extends Serializable {}
/**
* An BaseComQueue item for computation.
*/
public abstract class ComputeFunction implements ComQueueItem {
/**
* Perform the computation work.
*
* @param context to get input object and update output object.
*/
public abstract void calc(ComContext context);
}
/**
* An BaseComQueue item for communication.
*/
public abstract class CommunicateFunction implements ComQueueItem {
/**
* Perform communication work.
*
* @param input output of previous queue item.
* @param sessionId session id for shared objects.
* @param <T> Type of dataset.
* @return result dataset.
*/
public abstract <T> DataSet <T> communicateWith(DataSet <T> input, int sessionId);
}
结合我们代码来看,KMeansTrainBatchOp算法组件的部分作用是:KMeans算法被分割成若干CommunicateFunction。然后被添加到计算通讯队列上。
下面代码中,具体 Item 如下:
-
**ComputeFunction** :KMeansPreallocateCentroid,KMeansAssignCluster,KMeansUpdateCentroids
-
**CommunicateFunction** :AllReduce
-
**CompareCriterionFunction** :KMeansIterTermination
-
**CompleteResultFunction** : KMeansOutputModel
即算法实现的主要工作是:
-
构建了一个IterativeComQueue。
-
初始化数据,这里有两种办法:initWithPartitionedData将DataSet分片缓存至内存。initWithBroadcastData将DataSet整体缓存至每个worker的内存。
-
将计算分割为若干ComputeFunction,串联在IterativeComQueue
-
运用AllReduce通信模型完成了数据同步
static DataSet iterateICQ(...省略...) {
return new IterativeComQueue() .initWithPartitionedData(TRAIN_DATA, data) .initWithBroadcastData(INIT_CENTROID, initCentroid) .initWithBroadcastData(KMEANS_STATISTICS, statistics) .add(new KMeansPreallocateCentroid()) .add(new KMeansAssignCluster(distance)) .add(new AllReduce(CENTROID_ALL_REDUCE)) .add(new KMeansUpdateCentroids(distance)) .setCompareCriterionOfNode0(new KMeansIterTermination(distance, tol)) .closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName)) .setMaxIter(maxIter) .exec(); }
3. 计算/通讯队列
BaseComQueue 就是这个迭代框架的基础。它维持了一个 List<ComQueueItem> queue。用户在生成算法模块时候,会把各种 Function 添加到队列中。
IterativeComQueue 是 BaseComQueue 的缺省实现,具体实现了setMaxIter,setCompareCriterionOfNode0两个函数。
BaseComQueue两个重要函数是:
-
optimize 函数:把队列上相邻的 ComputeFunction串联起来,形成一个 ChainedComputation。在框架中进行优化,就是Alink的一个优势所在。
-
exec 函数:运行队列上的各个 Function,返回最终的 Dataset。实际上,这里才真正到了 Flink,比如把计算队列上的各个 ComputeFunction 映射到 Flink 的 RichMapPartitionFunction。然后在mapPartition函数调用中,会调用真实算法逻辑片断 `computation.calc(context);`。
可以认为,BaseComQueue 是个逻辑概念,让算法工程师可以更好的组织自己的业务语言。而通过其exec函数把算法逻辑映射到Flink算子上。这样在某种程度上起到了与Flink解耦合的作用。
具体定义(摘取函数内部分代码)如下:
// Base class for the com(Computation && Communicate) queue.
public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {
/**
* All computation or communication functions.
*/
private final List<ComQueueItem> queue = new ArrayList<>();
/**
* The function executed to decide whether to break the loop.
*/
private CompareCriterionFunction compareCriterion;
/**
* The function executed when closing the iteration
*/
private CompleteResultFunction completeResult;
private void optimize() {
if (queue.isEmpty()) {
return;
}
int current = 0;
for (int ahead = 1; ahead < queue.size(); ++ahead) {
ComQueueItem curItem = queue.get(current);
ComQueueItem aheadItem = queue.get(ahead);
// 这里进行判断,是否是前后都是 ComputeFunction,然后合并成 ChainedComputation
if (aheadItem instanceof ComputeFunction && curItem instanceof ComputeFunction) {
if (curItem instanceof ChainedComputation) {
queue.set(current, ((ChainedComputation) curItem).add((ComputeFunction) aheadItem));
} else {
queue.set(current, new ChainedComputation()
.add((ComputeFunction) curItem)
.add((ComputeFunction) aheadItem)
);
}
} else {
queue.set(++current, aheadItem);
}
}
queue.subList(current + 1, queue.size()).clear();
}
/**
* Execute the BaseComQueue and get the result dataset.
*
* @return result dataset.
*/
public DataSet<Row> exec() {
optimize();
IterativeDataSet<byte[]> loop
= loopStartDataSet(executionEnvironment)
.iterate(maxIter);
DataSet<byte[]> input = loop
.mapPartition(new DistributeData(cacheDataObjNames, sessionId))
.withBroadcastSet(loop, "barrier")
.name("distribute data");
for (ComQueueItem com : queue) {
if ((com instanceof CommunicateFunction)) {
CommunicateFunction communication = ((CommunicateFunction) com);
// 这里会调用比如 AllReduce.communication, 其会返回allReduce包装后赋值给input,当循环遇到了下一个ComputeFunction(KMeansUpdateCentroids)时候,会把input赋给它处理。比如input = {MapPartitionOperator@5248},input.function = {AllReduce$AllReduceRecv@5260},input调用mapPartition,去间接调用KMeansUpdateCentroids。
input = communication.communicateWith(input, sessionId);
} else if (com instanceof ComputeFunction) {
final ComputeFunction computation = (ComputeFunction) com;
// 这里才到了 Flink,把计算队列上的各个 ComputeFunction 映射到 Flink 的RichMapPartitionFunction。
input = input
.mapPartition(new RichMapPartitionFunction<byte[], byte[]>() {
@Override
public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
ComContext context = new ComContext(
sessionId, getIterationRuntimeContext()
);
// 在这里会被Flink调用具体计算函数,就是之前算法工程师拆分的算法片段。
computation.calc(context);
}
})
.withBroadcastSet(input, "barrier")
.name(com instanceof ChainedComputation ?
((ChainedComputation) com).name()
: "computation@" + computation.getClass().getSimpleName());
} else {
throw new RuntimeException("Unsupported op in iterative queue.");
}
}
return serializeModel(clearObjs(loopEnd));
}
}
4. Mapper(Function)
Mapper是底层迭代计算框架的一部分,可以认为是 Mapper Function。因为涉及到业务逻辑,所以提前说明。
5. 初始化
初始化发生在 KMeansTrainBatchOp.linkFrom 中。我们可以看到在初始化时候,是可以调用 Flink 各种算子(比如.rebalance().map()) ,因为这时候还没有和框架相关联,这时候的计算是用户自行控制,不需要加到 IterativeComQueue 之上。
如果某一个计算既要加到 IterativeComQueue 之上,还要自己玩 Flink 算子,那框架就懵圈了,不知道该如何处理。所以用户自由操作只能发生在没有和框架联系之前。
@Override
public KMeansTrainBatchOp linkFrom(BatchOperator <?>... inputs) {
DataSet <FastDistanceVectorData> data = statistics.f0.rebalance().map(
new MapFunction <Vector, FastDistanceVectorData>() {
@Override
public FastDistanceVectorData map(Vector value) {
return distance.prepareVectorData(Row.of(value), 0);
}
});
......
}
框架也提供了初始化功能,用于将DataSet缓存到内存中,缓存的形式包括Partition和Broadcast两种形式。前者将DataSet分片缓存至内存,后者将DataSet整体缓存至每个worker的内存。
return new IterativeComQueue()
.initWithPartitionedData(TRAIN_DATA, data)
.initWithBroadcastData(INIT_CENTROID, initCentroid)
.initWithBroadcastData(KMEANS_STATISTICS, statistics)
......
6. ComputeFunction
这是算法的具体计算模块,算法工程师应该把算法拆分成各个可以并行处理的模块,分别用 ComputeFunction 实现,这样可以利用 Flnk 的分布式计算效力。
下面举出一个例子如下,这段代码为每个点(point)计算最近的聚类中心,为每个聚类中心的点坐标的计数和求和:
/**
* Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster.
*/
public class KMeansAssignCluster extends ComputeFunction {
private FastDistance fastDistance;
private transient DenseMatrix distanceMatrix;
@Override
public void calc(ComContext context) {
Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
Integer k = context.getObj(KMeansTrainBatchOp.K);
// get iterative coefficient from static memory.
Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
if (context.getStepNo() % 2 == 0) {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
} else {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
}
if (null == distanceMatrix) {
distanceMatrix = new DenseMatrix(k, 1);
}
double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
if (sumMatrixData == null) {
sumMatrixData = new double[k * (vectorSize + 1)];
context.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrixData);
}
Iterable<FastDistanceVectorData> trainData = context.getObj(KMeansTrainBatchOp.TRAIN_DATA);
if (trainData == null) {
return;
}
Arrays.fill(sumMatrixData, 0.0);
for (FastDistanceVectorData sample : trainData) {
KMeansUtil.updateSumMatrix(sample, 1, stepNumCentroids.f1, vectorSize, sumMatrixData, k, fastDistance,
distanceMatrix);
}
}
}
这里能够看出,在 ComputeFunction 中,使用的是 命令式编程模式,这样能够最大的契合目前程序员现状,极高提升生产力。
7. CommunicateFunction
前面代码中有一个关键处 .add(new AllReduce(CENTROID_ALL_REDUCE))。这部分代码起到了承前启后的作用。之前的 KMeansPreallocateCentroid,KMeansAssignCluster和其后的 KMeansUpdateCentroids通过它做了一个 reduce / broadcast 通讯。
具体从注解中可以看到,AllReduce 是 MPI 相关通讯原语的一个实现。这里主要是对 double[] object 进行 reduce / broadcast。
public class AllReduce extends CommunicateFunction {
public static <T> DataSet <T> allReduce(
DataSet <T> input,
final String bufferName,
final String lengthName,
final SerializableBiConsumer <double[], double[]> op,
final int sessionId) {
final String transferBufferName = UUID.randomUUID().toString();
return input
.mapPartition(new AllReduceSend <T>(bufferName, lengthName, transferBufferName, sessionId))
.withBroadcastSet(input, "barrier")
.returns(
new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
.name("AllReduceSend")
.partitionCustom(new Partitioner <Integer>() {
@Override
public int partition(Integer key, int numPartitions) {
return key;
}
}, 0)
.name("AllReduceBroadcastRaw")
.mapPartition(new AllReduceSum(bufferName, lengthName, sessionId, op))
.returns(
new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
.name("AllReduceSum")
.partitionCustom(new Partitioner <Integer>() {
@Override
public int partition(Integer key, int numPartitions) {
return key;
}
}, 0)
.name("AllReduceBroadcastSum")
.mapPartition(new AllReduceRecv <T>(bufferName, lengthName, sessionId))
.returns(input.getType())
.name("AllReduceRecv");
}
}
经过调试我们能看出来,AllReduceSum 是在自己mapPartition实现中,调用了 SUM。
/**
* The all-reduce operation which does elementwise sum operation.
*/
public final static SerializableBiConsumer <double[], double[]> SUM
= new SerializableBiConsumer <double[], double[]>() {
@Override
public void accept(double[] a, double[] b) {
for (int i = 0; i < a.length; ++i) {
a[i] += b[i];
}
}
};
private static class AllReduceSum extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, Tuple3 <Integer, Integer, double[]>> {
@Override
public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values,
Collector <Tuple3 <Integer, Integer, double[]>> out) {
// 省略各种初始化操作,比如确定传输位置,传输目标等
......
do {
Tuple3 <Integer, Integer, double[]> val = it.next();
int localPos = val.f1 - startPos;
if (sum[localPos] == null) {
sum[localPos] = val.f2;
agg[localPos]++;
} else {
// 这里会调用 SUM
op.accept(sum[localPos], val.f2);
}
} while (it.hasNext());
for (int i = 0; i < numOfSubTasks; ++i) {
for (int j = 0; j < cnt; ++j) {
out.collect(Tuple3.of(i, startPos + j, sum[j]));
}
}
}
}
accept:129, AllReduce$3 (com.alibaba.alink.common.comqueue.communication)
accept:126, AllReduce$3 (com.alibaba.alink.common.comqueue.communication)
mapPartition:314, AllReduce$AllReduceSum (com.alibaba.alink.common.comqueue.communication)
run:103, MapPartitionDriver (org.apache.flink.runtime.operators)
run:504, BatchTask (org.apache.flink.runtime.operators)
run:157, AbstractIterativeTask (org.apache.flink.runtime.iterative.task)
run:107, IterationIntermediateTask (org.apache.flink.runtime.iterative.task)
invoke:369, BatchTask (org.apache.flink.runtime.operators)
doRun:705, Task (org.apache.flink.runtime.taskmanager)
run:530, Task (org.apache.flink.runtime.taskmanager)
run:745, Thread (java.lang)
0x02 另一种打法
总结到现在,我们发现这个迭代计算框架设计的非常优秀。但是Alink并没有限定大家只能使用这个框架来实现算法。如果你是Flink高手,你完全可以随心所欲的实现。
Alink例子中本身就有一个这样的实现 ALSExample。其核心类 AlsTrainBatchOp 就是直接使用了 Flink 算子,IterativeDataSet 等。
这就好比是武松武都头,一双戒刀搠得倒贪官佞臣,赤手空拳也打得死吊睛白额大虫。
public final class AlsTrainBatchOp
extends BatchOperator<AlsTrainBatchOp>
implements AlsTrainParams<AlsTrainBatchOp> {
@Override
public AlsTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
......
AlsTrain als = new AlsTrain(rank, numIter, lambda, implicitPrefs, alpha, numMiniBatches, nonNegative);
DataSet<Tuple3<Byte, Long, float[]>> factors = als.fit(alsInput);
DataSet<Row> output = factors.mapPartition(new RichMapPartitionFunction<Tuple3<Byte, Long, float[]>, Row>() {
@Override
public void mapPartition(Iterable<Tuple3<Byte, Long, float[]>> values, Collector<Row> out) {
new AlsModelDataConverter(userColName, itemColName).save(values, out);
}
});
return this;
}
}
多提一点,Flink ML中也有ALS算法,是一个Scala实现。没有Scala经验的算法工程师看代码会咬碎钢牙。
0x03 总结
经过这两篇文章的推测和验证,现在我们总结如下。
Alink的部分设计原则
-
算法的归算法,Flink的归Flink,尽量屏蔽AI算法和Flink之间的联系。
-
采用最简单,最常见的开发语言和思维方式。
-
尽量借鉴市面上通用的机器学习设计思路和开发模式,让开发者无缝切换。
-
构建一套战术打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,还可以让用户基于此可以快速开发算法。
针对这些原则,Alink实现了
-
顶层流水线,Estimator, Transformer...
-
算法组件中间层
-
底层迭代计算框架
这样Alink即可以最大限度的享受Flink带来的各种优势,也能顺应目前形势,让算法工程师工作更方便。从而达到系统性能和生产力的双重提升。
下一篇文章争取介绍 AllReduce 的具体实现。
0xFF 参考
Spark ML简介之Pipeline,DataFrame,Estimator,Transformer
斩获GitHub 2000+ Star,阿里云开源的 Alink 机器学习平台如何跑赢双11数据“博弈”?|AI 技术生态论
★★★★★★关于生活和技术的思考★★★★★★
微信公众账号:罗西的思考
如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。
本文使用 mdnice 排版