从源码看机器学习平台Alink设计和架构 之 迭代计算框架

349 阅读7分钟

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文是漫谈系列的第二篇,将从源码入手,带领大家具体剖析Alink设计思想和架构为何。

因为Alink的公开资料太少,所以均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

书接上文,我们讲解迭代计算框架。

0x01 底层--迭代计算框架

这里对应如下设计原则:

  • 构建一套战术打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,还可以让用户基于此可以快速开发算法。
  • 采用最简单,最常见的开发语言和开发模式。

让我们想想看,大概有哪些基础工作需要做:

  • 如何初始化
  • 如何通信
  • 如何分割代码,如何广播代码
  • 如何分割数据,如何广播数据
  • 如何迭代算法

其中最重要的概念是IterativeComQueue,这是把通信或者计算抽象成ComQueueItem,然后把ComQueueItem串联起来形成队列。这样就形成了面向迭代计算场景的一套迭代通信计算框架。

再次把目录结构列在这里:

./java/com/alibaba/alink/common:
MLEnvironment.java  linalg MLEnvironmentFactory.java mapper
VectorTypes.java  model comqueue   utils io

里面大致有 :

  • Flink 封装模块 :MLEnvironment.java, MLEnvironmentFactory.java。
  • 线性代数模块:linalg。
  • 计算/通讯队列模块:comqueue,其中ComputeFunction进行计算,比如训练算法。
  • 映射模块:mapper,其中Mapper进行各种映射,比如 ModelMapper 把模型映射为数值(就是转换算法)。
  • 模型 :model,主要是用来读取model source。
  • 基础模块:utils,io。

算法组件在其linkFrom函数中,会做如下操作:

  • 先进行部分初始化,此时会调用部分Flink算子,比如groupBy等等。
  • 再将算法逻辑剥离出来,委托给Mapper或者ComQueueItem。
  • Mapper或者ComQueueItem会调用Flink map算子或者mapPartition算子等。
  • 调用Flink算子过程就是把算法分割然后适配到Flink上的过程。

下面就一一阐述。

1. Flink上下文封装

MLEnvironment 是个重要的类。其封装了Flink开发所必须要的运行上下文。用户可以通过这个类来获取各种实际运行环境,可以建立table,可以运行SQL语句。

/**
 * The MLEnvironment stores the necessary context in Flink.
 * Each MLEnvironment will be associated with a unique ID.
 * The operations associated with the same MLEnvironment ID
 * will share the same Flink job context.
 */
public class MLEnvironment {
    private ExecutionEnvironment env;
    private StreamExecutionEnvironment streamEnv;
    private BatchTableEnvironment batchTableEnv;
    private StreamTableEnvironment streamTableEnv;
}

2. Function

Function是计算框架中,对于计算和通讯等业务逻辑的最小模块。具体定义如下。

  • ComputeFunction 是计算模块。
  • CommunicateFunction 是通讯模块。CommunicateFunction和ComputeFunction都是ComQueueItem子类,它们是业务逻辑实现者。
  • CompareCriterionFunction 是判断模块,用来判断何时结束循环。这就允许用户指定迭代终止条件。
  • CompleteResultFunction 用来在结束循环时候调用,作为循环结果。
  • Mapper也是一种Funciton,即Mapper Function。

后续将统称为 Function。

/**
 * Basic build block in {@link BaseComQueue}, for either communication or computation.
 */
public interface ComQueueItem extends Serializable {}

/**
 * An BaseComQueue item for computation.
 */
public abstract class ComputeFunction implements ComQueueItem {

 /**
  * Perform the computation work.
  *
  * @param context to get input object and update output object.
  */
 public abstract void calc(ComContext context);
}

/**
 * An BaseComQueue item for communication.
 */
public abstract class CommunicateFunction implements ComQueueItem {

    /**
     * Perform communication work.
     *
     * @param input     output of previous queue item.
     * @param sessionId session id for shared objects.
     * @param <T>       Type of dataset.
     * @return result dataset.
     */
 public abstract <T> DataSet <T> communicateWith(DataSet <T> input, int sessionId);
}

结合我们代码来看,KMeansTrainBatchOp算法组件的部分作用是:KMeans算法被分割成若干CommunicateFunction。然后被添加到计算通讯队列上。

下面代码中,具体 Item 如下:

  • **ComputeFunction** :KMeansPreallocateCentroid,KMeansAssignCluster,KMeansUpdateCentroids
  • **CommunicateFunction** :AllReduce
  • **CompareCriterionFunction** :KMeansIterTermination
  • **CompleteResultFunction** : KMeansOutputModel

即算法实现的主要工作是:

  • 构建了一个IterativeComQueue。
  • 初始化数据,这里有两种办法:initWithPartitionedData将DataSet分片缓存至内存。initWithBroadcastData将DataSet整体缓存至每个worker的内存。
  • 将计算分割为若干ComputeFunction,串联在IterativeComQueue
  • 运用AllReduce通信模型完成了数据同步

    static DataSet iterateICQ(...省略...) {

    return new IterativeComQueue() .initWithPartitionedData(TRAIN_DATA, data) .initWithBroadcastData(INIT_CENTROID, initCentroid) .initWithBroadcastData(KMEANS_STATISTICS, statistics) .add(new KMeansPreallocateCentroid()) .add(new KMeansAssignCluster(distance)) .add(new AllReduce(CENTROID_ALL_REDUCE)) .add(new KMeansUpdateCentroids(distance)) .setCompareCriterionOfNode0(new KMeansIterTermination(distance, tol)) .closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName)) .setMaxIter(maxIter) .exec(); }

3. 计算/通讯队列

BaseComQueue 就是这个迭代框架的基础。它维持了一个 List<ComQueueItem> queue。用户在生成算法模块时候,会把各种 Function 添加到队列中。

IterativeComQueue 是 BaseComQueue 的缺省实现,具体实现了setMaxIter,setCompareCriterionOfNode0两个函数。

BaseComQueue两个重要函数是:

  • optimize 函数:把队列上相邻的 ComputeFunction串联起来,形成一个 ChainedComputation。在框架中进行优化,就是Alink的一个优势所在。
  • exec 函数:运行队列上的各个 Function,返回最终的 Dataset。实际上,这里才真正到了 Flink,比如把计算队列上的各个 ComputeFunction 映射到 Flink 的 RichMapPartitionFunction。然后在mapPartition函数调用中,会调用真实算法逻辑片断 `computation.calc(context);`。

可以认为,BaseComQueue 是个逻辑概念,让算法工程师可以更好的组织自己的业务语言。而通过其exec函数把算法逻辑映射到Flink算子上。这样在某种程度上起到了与Flink解耦合的作用。

具体定义(摘取函数内部分代码)如下:

// Base class for the com(Computation && Communicate) queue.
public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {

 /**
  * All computation or communication functions.
  */
 private final List<ComQueueItem> queue = new ArrayList<>();

 /**
  * The function executed to decide whether to break the loop.
  */
 private CompareCriterionFunction compareCriterion;

 /**
  * The function executed when closing the iteration
  */
 private CompleteResultFunction completeResult;    

 private void optimize() {
  if (queue.isEmpty()) {
   return;
  }

  int current = 0;
  for (int ahead = 1; ahead < queue.size(); ++ahead) {
   ComQueueItem curItem = queue.get(current);
   ComQueueItem aheadItem = queue.get(ahead);

            // 这里进行判断,是否是前后都是 ComputeFunction,然后合并成 ChainedComputation
   if (aheadItem instanceof ComputeFunction && curItem instanceof ComputeFunction) {
    if (curItem instanceof ChainedComputation) {
     queue.set(current, ((ChainedComputation) curItem).add((ComputeFunction) aheadItem));
    } else {
     queue.set(current, new ChainedComputation()
      .add((ComputeFunction) curItem)
      .add((ComputeFunction) aheadItem)
     );
    }
   } else {
    queue.set(++current, aheadItem);
   }
  }

  queue.subList(current + 1, queue.size()).clear();
 }    

 /**
  * Execute the BaseComQueue and get the result dataset.
  *
  * @return result dataset.
  */
 public DataSet<Row> exec() {

  optimize();

  IterativeDataSet<byte[]> loop
   = loopStartDataSet(executionEnvironment)
   .iterate(maxIter);

  DataSet<byte[]> input = loop
   .mapPartition(new DistributeData(cacheDataObjNames, sessionId))
   .withBroadcastSet(loop, "barrier")
   .name("distribute data");

  for (ComQueueItem com : queue) {
   if ((com instanceof CommunicateFunction)) {
    CommunicateFunction communication = ((CommunicateFunction) com);         
         // 这里会调用比如 AllReduce.communication, 其会返回allReduce包装后赋值给input,当循环遇到了下一个ComputeFunction(KMeansUpdateCentroids)时候,会把input赋给它处理。比如input = {MapPartitionOperator@5248},input.function = {AllReduce$AllReduceRecv@5260},input调用mapPartition,去间接调用KMeansUpdateCentroids。              
    input = communication.communicateWith(input, sessionId);
   } else if (com instanceof ComputeFunction) {
    final ComputeFunction computation = (ComputeFunction) com;

        // 这里才到了 Flink,把计算队列上的各个 ComputeFunction 映射到 Flink 的RichMapPartitionFunction。
    input = input
      .mapPartition(new RichMapPartitionFunction<byte[], byte[]>() {

      @Override
      public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
       ComContext context = new ComContext(
        sessionId, getIterationRuntimeContext()
       );
              // 在这里会被Flink调用具体计算函数,就是之前算法工程师拆分的算法片段。
       computation.calc(context);
      }
     })
     .withBroadcastSet(input, "barrier")
     .name(com instanceof ChainedComputation ?
      ((ChainedComputation) com).name()
      : "computation@" + computation.getClass().getSimpleName());
   } else {
    throw new RuntimeException("Unsupported op in iterative queue.");
   }
  }

  return serializeModel(clearObjs(loopEnd));
 }
}

4. Mapper(Function)

Mapper是底层迭代计算框架的一部分,可以认为是 Mapper Function。因为涉及到业务逻辑,所以提前说明。

5. 初始化

初始化发生在 KMeansTrainBatchOp.linkFrom 中。我们可以看到在初始化时候,是可以调用 Flink 各种算子(比如.rebalance().map()) ,因为这时候还没有和框架相关联,这时候的计算是用户自行控制,不需要加到 IterativeComQueue 之上。

如果某一个计算既要加到 IterativeComQueue 之上,还要自己玩 Flink 算子,那框架就懵圈了,不知道该如何处理。所以用户自由操作只能发生在没有和框架联系之前。

 @Override
 public KMeansTrainBatchOp linkFrom(BatchOperator <?>... inputs) {
  DataSet <FastDistanceVectorData> data = statistics.f0.rebalance().map(
   new MapFunction <Vector, FastDistanceVectorData>() {
    @Override
    public FastDistanceVectorData map(Vector value) {
     return distance.prepareVectorData(Row.of(value), 0);
    }
   });
  ......     
    }

框架也提供了初始化功能,用于将DataSet缓存到内存中,缓存的形式包括Partition和Broadcast两种形式。前者将DataSet分片缓存至内存,后者将DataSet整体缓存至每个worker的内存。

  return new IterativeComQueue()
   .initWithPartitionedData(TRAIN_DATA, data)
   .initWithBroadcastData(INIT_CENTROID, initCentroid)
   .initWithBroadcastData(KMEANS_STATISTICS, statistics)
            ......

6. ComputeFunction

这是算法的具体计算模块,算法工程师应该把算法拆分成各个可以并行处理的模块,分别用 ComputeFunction 实现,这样可以利用 Flnk 的分布式计算效力。

下面举出一个例子如下,这段代码为每个点(point)计算最近的聚类中心,为每个聚类中心的点坐标的计数和求和:

/**
 * Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster.
 */
public class KMeansAssignCluster extends ComputeFunction {
    private FastDistance fastDistance;
    private transient DenseMatrix distanceMatrix;

    @Override
    public void calc(ComContext context) {
        Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
        Integer k = context.getObj(KMeansTrainBatchOp.K);
        // get iterative coefficient from static memory.
        Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
        if (context.getStepNo() % 2 == 0) {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
        } else {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
        }

        if (null == distanceMatrix) {
            distanceMatrix = new DenseMatrix(k, 1);
        }

        double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
        if (sumMatrixData == null) {
            sumMatrixData = new double[k * (vectorSize + 1)];
            context.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrixData);
        }

        Iterable<FastDistanceVectorData> trainData = context.getObj(KMeansTrainBatchOp.TRAIN_DATA);
        if (trainData == null) {
            return;
        }

        Arrays.fill(sumMatrixData, 0.0);
        for (FastDistanceVectorData sample : trainData) {
            KMeansUtil.updateSumMatrix(sample, 1, stepNumCentroids.f1, vectorSize, sumMatrixData, k, fastDistance,
                distanceMatrix);
        }
    }
}

这里能够看出,在 ComputeFunction 中,使用的是 命令式编程模式,这样能够最大的契合目前程序员现状,极高提升生产力。

7. CommunicateFunction

前面代码中有一个关键处 .add(new AllReduce(CENTROID_ALL_REDUCE))。这部分代码起到了承前启后的作用。之前的 KMeansPreallocateCentroid,KMeansAssignCluster和其后的 KMeansUpdateCentroids通过它做了一个 reduce / broadcast 通讯。

具体从注解中可以看到,AllReduce 是 MPI 相关通讯原语的一个实现。这里主要是对 double[] object 进行 reduce / broadcast。

public class AllReduce extends CommunicateFunction {
 public static <T> DataSet <T> allReduce(
  DataSet <T> input,
  final String bufferName,
  final String lengthName,
  final SerializableBiConsumer <double[], double[]> op,
  final int sessionId) {
  final String transferBufferName = UUID.randomUUID().toString();

  return input
   .mapPartition(new AllReduceSend <T>(bufferName, lengthName, transferBufferName, sessionId))
   .withBroadcastSet(input, "barrier")
   .returns(
    new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
   .name("AllReduceSend")
   .partitionCustom(new Partitioner <Integer>() {
    @Override
    public int partition(Integer key, int numPartitions) {
     return key;
    }
   }, 0)
   .name("AllReduceBroadcastRaw")
   .mapPartition(new AllReduceSum(bufferName, lengthName, sessionId, op))
   .returns(
    new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
   .name("AllReduceSum")
   .partitionCustom(new Partitioner <Integer>() {
    @Override
    public int partition(Integer key, int numPartitions) {
     return key;
    }
   }, 0)
   .name("AllReduceBroadcastSum")
   .mapPartition(new AllReduceRecv <T>(bufferName, lengthName, sessionId))
   .returns(input.getType())
   .name("AllReduceRecv");
 }    
}

经过调试我们能看出来,AllReduceSum 是在自己mapPartition实现中,调用了 SUM。

 /**
  * The all-reduce operation which does elementwise sum operation.
  */
 public final static SerializableBiConsumer <double[], double[]> SUM
  = new SerializableBiConsumer <double[], double[]>() {
  @Override
  public void accept(double[] a, double[] b) {
   for (int i = 0; i < a.length; ++i) {
    a[i] += b[i];
   }
  }
 };

private static class AllReduceSum extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, Tuple3 <Integer, Integer, double[]>> {
  @Override
  public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values,
         Collector <Tuple3 <Integer, Integer, double[]>> out) {

      // 省略各种初始化操作,比如确定传输位置,传输目标等
      ......

   do {
    Tuple3 <Integer, Integer, double[]> val = it.next();
    int localPos = val.f1 - startPos;
    if (sum[localPos] == null) {
     sum[localPos] = val.f2;
     agg[localPos]++;
    } else {
          // 这里会调用 SUM 
     op.accept(sum[localPos], val.f2);
    }
   } while (it.hasNext());

   for (int i = 0; i < numOfSubTasks; ++i) {
    for (int j = 0; j < cnt; ++j) {
     out.collect(Tuple3.of(i, startPos + j, sum[j]));
    }
   }
  }
 }

accept:129, AllReduce$3 (com.alibaba.alink.common.comqueue.communication)
accept:126, AllReduce$3 (com.alibaba.alink.common.comqueue.communication)
mapPartition:314, AllReduce$AllReduceSum (com.alibaba.alink.common.comqueue.communication)
run:103, MapPartitionDriver (org.apache.flink.runtime.operators)
run:504, BatchTask (org.apache.flink.runtime.operators)
run:157, AbstractIterativeTask (org.apache.flink.runtime.iterative.task)
run:107, IterationIntermediateTask (org.apache.flink.runtime.iterative.task)
invoke:369, BatchTask (org.apache.flink.runtime.operators)
doRun:705, Task (org.apache.flink.runtime.taskmanager)
run:530, Task (org.apache.flink.runtime.taskmanager)
run:745, Thread (java.lang)

0x02 另一种打法

总结到现在,我们发现这个迭代计算框架设计的非常优秀。但是Alink并没有限定大家只能使用这个框架来实现算法。如果你是Flink高手,你完全可以随心所欲的实现。

Alink例子中本身就有一个这样的实现 ALSExample。其核心类 AlsTrainBatchOp 就是直接使用了 Flink 算子,IterativeDataSet 等。

这就好比是武松武都头,一双戒刀搠得倒贪官佞臣,赤手空拳也打得死吊睛白额大虫。

public final class AlsTrainBatchOp
    extends BatchOperator<AlsTrainBatchOp>
    implements AlsTrainParams<AlsTrainBatchOp> {

    @Override
    public AlsTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);

   ......

        AlsTrain als = new AlsTrain(rank, numIter, lambda, implicitPrefs, alpha, numMiniBatches, nonNegative);
        DataSet<Tuple3<Byte, Long, float[]>> factors = als.fit(alsInput);

        DataSet<Row> output = factors.mapPartition(new RichMapPartitionFunction<Tuple3<Byte, Long, float[]>, Row>() {
            @Override
            public void mapPartition(Iterable<Tuple3<Byte, Long, float[]>> values, Collector<Row> out) {
                new AlsModelDataConverter(userColName, itemColName).save(values, out);
            }
        });

        return this;
    }
}

多提一点,Flink ML中也有ALS算法,是一个Scala实现。没有Scala经验的算法工程师看代码会咬碎钢牙。

0x03 总结

经过这两篇文章的推测和验证,现在我们总结如下。

Alink的部分设计原则

  • 算法的归算法,Flink的归Flink,尽量屏蔽AI算法和Flink之间的联系。

  • 采用最简单,最常见的开发语言和思维方式。

  • 尽量借鉴市面上通用的机器学习设计思路和开发模式,让开发者无缝切换。

  • 构建一套战术打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,还可以让用户基于此可以快速开发算法。

针对这些原则,Alink实现了

  • 顶层流水线,Estimator, Transformer...
  • 算法组件中间层
  • 底层迭代计算框架

这样Alink即可以最大限度的享受Flink带来的各种优势,也能顺应目前形势,让算法工程师工作更方便。从而达到系统性能和生产力的双重提升。

下一篇文章争取介绍 AllReduce 的具体实现。

0xFF 参考

k-means聚类算法原理简析

flink kmeans聚类算法实现

Spark ML简介之Pipeline,DataFrame,Estimator,Transformer

开源 | 全球首个批流一体机器学习平台

斩获GitHub 2000+ Star,阿里云开源的 Alink 机器学习平台如何跑赢双11数据“博弈”?|AI 技术生态论

★★★★★★关于生活和技术的思考★★★★★★

微信公众账号:罗西的思考

如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。

本文使用 mdnice 排版