Alink漫谈(三)AllReduce通信模型 之 AllReduce

443 阅读9分钟

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中通讯模型AllReduce的实现。

AllReduce在Alink中应用较多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了这个通讯模型。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

0x01. 示例代码

我们示例代码依然如下。

KMeansTrainBatchOp调用

 static DataSet <Row> iterateICQ(...省略...) {
  return new IterativeComQueue()
   .initWithPartitionedData(TRAIN_DATA, data)
   .initWithBroadcastData(INIT_CENTROID, initCentroid)
   .initWithBroadcastData(KMEANS_STATISTICS, statistics)
   .add(new KMeansPreallocateCentroid())
   .add(new KMeansAssignCluster(distance))
   .add(new AllReduce(CENTROID_ALL_REDUCE))
   .add(new KMeansUpdateCentroids(distance))
   .setCompareCriterionOfNode0(new KMeansIterTermination(distance, tol))
   .closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName))
   .setMaxIter(maxIter)
   .exec();
 }

AllReduce实现

Alink的AllReduce主要代码摘取如下:

public static <T> DataSet <T> allReduce(
    return input
  .mapPartition(new AllReduceSend <T>(bufferName, lengthName, transferBufferName, sessionId))
  .withBroadcastSet(input, "barrier")
  .returns(
   new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
  .name("AllReduceSend")
  .partitionCustom(new Partitioner <Integer>() {
   @Override
   public int partition(Integer key, int numPartitions) {
    return key;
   }
  }, 0)
  .name("AllReduceBroadcastRaw")
  .mapPartition(new AllReduceSum(bufferName, lengthName, sessionId, op))
  .returns(
   new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
  .name("AllReduceSum")
  .partitionCustom(new Partitioner <Integer>() {
   @Override
   public int partition(Integer key, int numPartitions) {
    return key;
   }
  }, 0)
  .name("AllReduceBroadcastSum")
  .mapPartition(new AllReduceRecv <T>(bufferName, lengthName, sessionId))
  .returns(input.getType())
  .name("AllReduceRecv");
}

0x02 AllReduce实现

结合上面具体代码,我们先总结AllReduce使用流程如下

  • KMeansAssignCluster :Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster。然后把自己计算出来的cluster 写入到自己 task manager 的 CENTROID_ALL_REDUCE。

  • 每个AllReduceSend 从自己task manager的CENTROID_ALL_REDUCE中取出之前存入的 cluster(每个AllReduceSend获取的cluster都是只有自己能看到的),然后发送给下游task。发送时根据 "下游task index 和 数据量" 来决定往哪些task发送。这里要注意的是:具体给哪一个task发送变量的哪一部分,是依据那个task 的 task index 和数据量 来计算出来的。这个计算机制(如何计算在代码中,也有部分作为元信息随着数据一起发送)被后面的AllReduceRecv复用。

  • 每个 AllReduceSum 接收到 AllReduceSend 发送过来的 cluster,计算求和,然后把计算结果再发送出去。每一个AllReduceSum 都是把自己计算求和出来的数据统一发给每一个下游task。

  • 每个 AllReduceRecv 都接收到 所有 AllReduceSum 发送过来的(求和之后的)cluster。存入到共享变量CENTROID_ALL_REDUCE。具体如何存就复用AllReduceSend计算机制,这样存到共享变量的什么地方就互相不会冲突。可以理解为merge操作:比如有5个AllReduce,每个AllReduce的数据都发给了5个AllReduceRecv,每个AllReduceRecv接到这5份数据之后,会根据自己的subtask index写到自己对应的state中,但是这5份数据分别写在state什么地方都是在数据元信息中指定的,彼此不会有写的冲突,这样每个AllReduceRecv就拥有了全部5份数据。

  • KMeansUpdateCentroids :取出CENTROID_ALL_REDUCE变量,然后Update the centroids based on the sum of points and point number belonging to the same cluster

1. KMeansAssignCluster

该类的作用是:为每个点(point)计算最近的聚类中心,为每个聚类中心的点坐标的计数和求和。

我们可以看出,KMeansAssignCluster 通过ComContext存储了CENTROID_ALL_REDUCE,为后续AllReduce使用。假如有5个KMeansAssignCluster,则他们计算出来的结果一般来说各不相同。虽然存储同一个变量名CENTROID_ALL_REDUCE,但是其state各不相同。

因为这5个KMeansAssignCluster势必对应了5个subtask,则其在共享变量中的<handle, taskId>必不相同,则对应不同的state,所以分开存储。

// Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster.
public class KMeansAssignCluster extends ComputeFunction {
        // 存取共享变量
        double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
        if (sumMatrixData == null) {
            sumMatrixData = new double[k * (vectorSize + 1)];
            context.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrixData);
        }  
    
        for (FastDistanceVectorData sample : trainData) {
            // Find the closest centroid from centroids for sample, and add the sample to sumMatrix.
            KMeansUtil.updateSumMatrix(sample, 1, stepNumCentroids.f1, vectorSize, sumMatrixData, k, fastDistance, distanceMatrix);
        }    
}

// 程序中各个变量如下

sample = {FastDistanceVectorData@13274} 
 vector = {DenseVector@13281} "6.3 2.5 4.9 1.5"
 label = {DenseVector@13282} "72.2"
 rows = {Row[1]@13283} 

// 这个就是共享变量。4维向量 + 1 weight ---> 都是"sample和"。
sumMatrixData = {double[15]@10574} 
 0 = 23.6
 1 = 14.9
 2 = 8.7
 3 = 1.7000000000000002
 4 = 5.0
 5 = 52.400000000000006
 6 = 25.1
 7 = 39.699999999999996
 8 = 13.299999999999999
 9 = 9.0
 10 = 33.0
 11 = 16.9
 12 = 28.900000000000002
 13 = 11.4
 14 = 5.0
     
trainData = {ArrayList@10580}  size = 19
 0 = {FastDistanceVectorData@10590} 
  vector = {DenseVector@10595} "7.7 3.8 6.7 2.2"
   data = {double[4]@10601} 
    0 = 7.7
    1 = 3.8
    2 = 6.7
    3 = 2.2
  label = {DenseVector@10596} "123.46000000000001"
  rows = {Row[1]@10597} 
 1 = {FastDistanceVectorData@10603} 
  vector = {DenseVector@10623} "5.7 2.8 4.1 1.3"
  label = {DenseVector@10624} "58.83"
  rows = {Row[1]@10625} 
 2 = {FastDistanceVectorData@10604} 
 3 = {FastDistanceVectorData@10605} 
......
 17 = {FastDistanceVectorData@10619} 
 18 = {FastDistanceVectorData@10620} 
  vector = {DenseVector@10654} "6.5 3.0 5.2 2.0"
  label = {DenseVector@10655} "82.29"
  rows = {Row[1]@10656}      

2. AllReduceSend

这里需要再把代码摘录一遍,主要是因为有withBroadcastSet。其作用是:

  • 可以理解为是一个公共的共享变量,我们可以把一个dataset 数据集广播出去,然后不同的task在节点上都能够获取到,这个数据在每个节点上只会存在一份。

  • 如果不使用broadcast,则在每个节点中的每个task中都需要拷贝一份dataset数据集,比较浪费内存(也就是一个节点中可能会存在多份dataset数据)。

    return input .mapPartition(new AllReduceSend (bufferName, lengthName, transferBufferName, sessionId)) .withBroadcastSet(input, "barrier")

KMeansAssignCluster 会往上下文的变量centroidAllReduce中添加数据。所以 AllReduce 其实就是在等待这个变量。

AllReduce的第一步就是从上下文中取出共享变量,然后发送。这部分代码由AllReduceSend完成。

对于AllReduceSend的每个task来说,bufferName都是 centroidAllReduce。

因为每个AllReduceSend也对应不同的task,所以每个AllReduceSend读取的centroidAllReduce必然不一样,所以每个task获取的sendBuf都不一样。他们分别把自己<handle, taskId>对应的 "centroidAllReduce" state取出,发送给下游。

AllReduceSend 发给其下游时候,是以subtask的序号为基准发送给每一个task,即本task中获取的共享变量会发送给每一个task,但是具体给哪一个task发送变量的那一部分,是依据那个task 的 task index 和数据量 来计算出来的。如果数据量少,可能只给某一个或者几个task发送。

后续中的 taskId ,都是subtask id。

其中,如何计算给哪个task发送多少,是在DefaultDistributedInfo完成的。这里需要结合 pieces 函数进行分析。需要注意的是:AllReduceSend这么发送,AllReduceRecv后面也按照这个套路接受。这样AllReduceRecv就可以merge了。

AllReduceSend这么发送,AllReduceRecv后面也按照这个套路接受

int pieces = pieces(sendLen);//表示本人这次send的数据分成几片,比如分成50片。每片大小是TRANSFER_BUFFER_SIZE

// 将要发给 8 个 subtask
for (int i = 0; i < numOfSubTasks; ++i) {
      // 假如第5个subtask,那么它发送的起始位置就是50/8 * 4
      int startPos = (int) distributedInfo.startPos(i, numOfSubTasks, pieces);
      // 给第5个subtask发送多少片
      int cnt = (int) distributedInfo.localRowCnt(i, numOfSubTasks, pieces);

具体代码如下:

 private static int pieces(int len) {
  int div = len / TRANSFER_BUFFER_SIZE; //本人这次send的数据分成几片,每片大小是TRANSFER_BUFFER_SIZE
  int mod = len % TRANSFER_BUFFER_SIZE;

  return mod == 0 ? div : div + 1;
 }

public class DefaultDistributedInfo implements DistributedInfo {

 public long startPos(long taskId, long parallelism, long globalRowCnt) {
  long div = globalRowCnt / parallelism;
  long mod = globalRowCnt % parallelism;

  if (mod == 0) {
   return div * taskId;
  } else if (taskId >= mod) {
   return div * taskId + mod;
  } else {
   return div * taskId + taskId;
  }
 }
    
 public long localRowCnt(long taskId, long parallelism, long globalRowCnt) {
  long div = globalRowCnt / parallelism;
  long mod = globalRowCnt % parallelism;

  if (mod == 0) {
   return div;
  } else if (taskId >= mod) {
   return div;
  } else {
   return div + 1;
  }
 }     
}

具体AllReduceSend代码如下,注解中有详细说明。

// 这里是变量名字定义。 
public static final String CENTROID_ALL_REDUCE = "centroidAllReduce";

private static class AllReduceSend<T> extends RichMapPartitionFunction <T, Tuple3 <Integer, Integer, double[]>> {
        
     int numOfSubTasks = getRuntimeContext().getNumberOfParallelSubtasks();
  // 与并行度相关,每个task都会执行相同操作
  // bufferName都是 centroidAllReduce,每个task获取的sendBuf都不一样
    
        // 计算怎么发送所需要的数据结构
     int pieces = pieces(sendLen);
     DistributedInfo distributedInfo = new DefaultDistributedInfo();

        // 从上下文中获取需要传送的数据
  double[] sendBuf = context.getObj(bufferName);
        
   int agg = 0;
      // 可以看出来,是把需要传送的数据给每个task都发送。当然这个发送是根据发送数据的大小来确定的,如果数据量小,可能就只给一个或者几个task发送。
   for (int i = 0; i < numOfSubTasks; ++i) {
                // startPos : 具体发送变量的那一部分,是依据task index来决定的。
                // cnt : 具体哪一个下游 task i 发送多少数据由此决定,如果是0,就不给task i发送数据。
    int startPos = (int) distributedInfo.startPos(i, numOfSubTasks, pieces);
    int cnt = (int) distributedInfo.localRowCnt(i, numOfSubTasks, pieces);

    for (int j = 0; j < cnt; ++j) {
                    // 发送哪一个部分
     int bufStart = (startPos + j) * TRANSFER_BUFFER_SIZE;
     // the last
     if (startPos + j == pieces - 1) {
      System.arraycopy(sendBuf, bufStart, transBuf, 0, lastLen(sendLen));
     } else {
      System.arraycopy(sendBuf, bufStart, transBuf, 0, TRANSFER_BUFFER_SIZE);
     }
     agg++;
                    
          // i 是subTasks的index,startPos + j是buffer内的位置,后续分区实际就是按照这个 i 来分区的。本AllReduceSend就是发送到numOfSubTasks这些task中。
     out.collect(Tuple3.of(i, startPos + j, transBuf));
    }
   }
}

 private static int pieces(int len) {
  int div = len / TRANSFER_BUFFER_SIZE; // 4096
  int mod = len % TRANSFER_BUFFER_SIZE;
  return mod == 0 ? div : div + 1;
 }

sendBuf = {double[15]@10602} 
 0 = 40.3
 1 = 18.200000000000003
 2 = 33.6
 3 = 12.5
 4 = 6.0
 5 = 45.3
 6 = 30.599999999999998
 7 = 12.4
 8 = 2.0
 9 = 9.0
 10 = 24.0
 11 = 10.4
 12 = 17.1
 13 = 5.199999999999999
 14 = 4.0

this = {AllReduce$AllReduceSend@10598} 
 bufferName = "centroidAllReduce"
 lengthName = null
 transferBufferName = "3dfb2aae-683d-4497-91fc-30b8d6853bce"
 sessionId = 0
 runtimeContext = {AbstractIterativeTask$IterativeRuntimeUdfContext@10606}       

3. AllReduceBroadcastRaw

AllReduceSend发送变量给下游时候,使用了自定义的partition(partitionCustom )。其是用 index of subtask 来作为key分区。这样就和AllReduceSend那个out.collect对应了。

   .partitionCustom(new Partitioner <Integer>() {
    @Override
    public int partition(Integer key, int numPartitions) {
     return key;
    }
   }, 0)
   .name("AllReduceBroadcastRaw")
               
// 调用到这个partition函数的调用栈
                
partition:102, AllReduce$2 (com.alibaba.alink.common.comqueue.communication)
partition:99, AllReduce$2 (com.alibaba.alink.common.comqueue.communication)
customPartition:235, OutputEmitter (org.apache.flink.runtime.operators.shipping)
selectChannel:149, OutputEmitter (org.apache.flink.runtime.operators.shipping)
selectChannel:36, OutputEmitter (org.apache.flink.runtime.operators.shipping)
emit:120, RecordWriter (org.apache.flink.runtime.io.network.api.writer)
collect:65, OutputCollector (org.apache.flink.runtime.operators.shipping)
collect:35, CountingCollector (org.apache.flink.runtime.operators.util.metrics)
mapPartition:257, AllReduce$AllReduceSend (com.alibaba.alink.common.comqueue.communication)
run:103, MapPartitionDriver (org.apache.flink.runtime.operators)
run:504, BatchTask (org.apache.flink.runtime.operators)
run:157, AbstractIterativeTask (org.apache.flink.runtime.iterative.task)
run:107, IterationIntermediateTask (org.apache.flink.runtime.iterative.task)
invoke:369, BatchTask (org.apache.flink.runtime.operators)
doRun:705, Task (org.apache.flink.runtime.taskmanager)
run:530, Task (org.apache.flink.runtime.taskmanager)
run:745, Thread (java.lang)                  
                
                 
 // @AllReduceSend.mapPartition 这里开始调用   
 for (int i = 0; i < numOfSubTasks; ++i) {   
     // i 是subTasks的index,后续分区实际就是按照这个 i 来分区的。本AllReduceSend就是发送到numOfSubTasks这些task中。
  out.collect(Tuple3.of(i, startPos + j, transBuf));     
 }
                
 // 从后续调用序列可以看出来,最终是用 index of subtask 来作为key分区。    

// 这里发送record

 public class CountingCollector<OUT> implements Collector<OUT> {
 public void collect(OUT record) {
  this.numRecordsOut.inc();
  this.collector.collect(record);
 }     
 }
             
 record = {Tuple3@10586} "(0,0,[40.50000000000001, 18.7, 33.300000000000004, 12.8, 6.0, 29.7, 21.0, 8.4, 1.7, 6.0, 48.1, 22.199999999999996, 36.0, 12.200000000000001, 8.0, 0.0,"
 f0 = {Integer@10583} 0
 f1 = {Integer@10583} 0
 f2 = {double[4096]@10598}                
       
// 这里开始分区

public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T>> {
 private int customPartition(T record, int numberOfChannels) {
  if (extractedKeys == null) {
   extractedKeys = new Object[1];
  }

  if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
            // 所以 key 是 0
   final Object key = extractedKeys[0];
   return partitioner.partition(key, numberOfChannels);
  }            
 }    
}

public final class TupleComparator<T extends Tuple> extends TupleComparatorBase<T> {
 public int extractKeys(Object record, Object[] target, int index) {
  int localIndex = index;
  for(int i = 0; i < comparators.length; i++) {
   localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex);
  }
  return localIndex - index;
 }    
}

// 就是取出第一个field的数值

key = {Integer@10583} 0
 value = 0
    
extractedKeys = {Object[1]@10587} 
 0 = {Integer@10583} 0
  value = 0

4. AllReduceSum

所有workers在它收到的数据上做reduce,然后把这个部分reduce的结果(partial results)发送给其他workers。

partial results是因为每个task接受的数据不同,是上游根据task index计算位置并且发送过来的。

但是AllReduceSum的计算结果会给每一个下游 task index 发送。

private static class AllReduceSum extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, Tuple3 <Integer, Integer, double[]>> {
    
     public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values,Collector <Tuple3 <Integer, Integer, double[]>> out) {
            
            // 这时候虽然也用到了context取出了sendBuf,但是只是用来获取其长度而已。
      int taskId = getRuntimeContext().getIndexOfThisSubtask();
   int numOfSubTasks = getRuntimeContext().getNumberOfParallelSubtasks();

   double[] sendBuf = context.getObj(bufferName);
   int sendLen = lengthName != null ? context.getObj(lengthName) : sendBuf.length;
   int pieces = pieces(sendLen);
   DistributedInfo distributedInfo = new DefaultDistributedInfo();

            // startPos : 本task接受的数据,startPos 是应该从原始数据的哪个位置开始。是依据task index来决定的。
            // cnt : 具体哪一个下游 task i 发送多少数据由此决定。   
   int startPos = (int) distributedInfo.startPos(taskId, numOfSubTasks, pieces);
   int cnt = (int) distributedInfo.localRowCnt(taskId, numOfSubTasks, pieces);
    
      // 这里进行了reduce SUM工作
   double[][] sum = new double[cnt][];
   double[] agg = new double[cnt];
   do {
    Tuple3 <Integer, Integer, double[]> val = it.next();
    int localPos = val.f1 - startPos;
    if (sum[localPos] == null) {
     sum[localPos] = val.f2;
     agg[localPos]++;
    } else {
     op.accept(sum[localPos], val.f2);
    }
   } while (it.hasNext());    
    
      // 依然发送给下游,依然是用subtask index来作为partition key。
            // 注意,这里是把结果发送给所有的下游task。
   for (int i = 0; i < numOfSubTasks; ++i) {
    for (int j = 0; j < cnt; ++j) {
          // startPos是本task发送的数据应该从原始数据的哪个位置开始。
          // 但是给每一个 task i 发的都是同样的数据。但是 startPos + j 很重要,下游task i 会根据这个知道它应该把接收到的数据存储在预定义变量的什么地方。
     out.collect(Tuple3.of(i, startPos + j, sum[j]));
    }
   }   
        }
}

sum = {double[1][]@10605} 
 0 = {double[4096]@10613} 
  0 = 118.50000000000001
  1 = 77.7
  2 = 37.2
  3 = 5.9
  4 = 25.0
  5 = 621.1000000000001
  6 = 284.7
  7 = 487.59999999999997
  8 = 166.5
  9 = 99.0
  10 = 136.9
  11 = 95.7
  12 = 39.0
  13 = 7.4
  14 = 26.0

5. AllReduceBroadcastSum

AllReduceSum 发送变量给下游时候,使用了自定义的partition(partitionCustom )。其是用 index of subtask 来作为key分区。

其意义和之前的 partitionCustom 相同。

6. AllReduceRecv

All workers merge partial results into final result and put it into session context with pre-defined object name.

每一个下游 AllReduceRecv 都接收到 每一个上游 AllReduceSum 发送过来的 cluster(求和之后的),然后把每份数据存入到自己task manager对应的预定义变量state的不同部分(这个不同部分是根据接受到的数据val.f1计算出来的)。

结合前面可知,AllReduceSend发送和AllReduceRecv接受,都是按照同样的套路计算在共享变量中的数据位置。这样AllReduceRecv就可以merge了。

这样就完成了所有workers把部分reduce sum的结果合并成为最终结果,然后放入预定义的上下文变量中。

 private static class AllReduceRecv<T> extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, T> {
  private final String bufferName;
  private final String lengthName;
  private final int sessionId;

  @Override
  public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values, Collector <T> out) throws Exception {
   ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
   Iterator <Tuple3 <Integer, Integer, double[]>> it = values.iterator();
   if (!it.hasNext()) {
    return;
   }
   double[] recvBuf = context.getObj(bufferName);
   int recvLen = lengthName != null ? context.getObj(lengthName) : recvBuf.length;
   int pieces = pieces(recvLen); // 和之前AllReduceSend一样的套路计算应该存储在共享变量什么位置。
   do {
    Tuple3 <Integer, Integer, double[]> val = it.next();
    if (val.f1 == pieces - 1) {
     System.arraycopy(val.f2, 0, recvBuf, val.f1 * TRANSFER_BUFFER_SIZE, lastLen(recvLen));
    } else {
           // 拷贝到共享变量的相应部位。val.f1 是上游发送过来的。作为merge功能的起始位置。
     System.arraycopy(val.f2, 0, recvBuf, val.f1 * TRANSFER_BUFFER_SIZE, TRANSFER_BUFFER_SIZE);
    }
   } while (it.hasNext());
  }
 }

val = {Tuple3@10672} "(3,0,[335.3, 150.89999999999998, 277.5, 99.79999999999998, 50.0, 290.9, 136.3, 213.1, 67.8, 50.0, 250.3, 170.89999999999998, 73.2, 12.2, 50.0, 0.0....."
 f0 = {Integer@10682} 3
  value = 3
 f1 = {Integer@10638} 0
  value = 0
 f2 = {double[4096]@10674} 
  0 = 335.3
  1 = 150.89999999999998
  2 = 277.5
  3 = 99.79999999999998
  4 = 50.0
  5 = 290.9
  6 = 136.3
  7 = 213.1
  8 = 67.8
  9 = 50.0
  10 = 250.3
  11 = 170.89999999999998
  12 = 73.2
  13 = 12.2
  14 = 50.0
  15 = 0.0
  ......
      
// 每个task都收到了reduce sum结果。      
recvBuf = {double[15]@10666} 
 0 = 404.3
 1 = 183.1
 2 = 329.3
 3 = 117.2
 4 = 61.0
 5 = 250.3
 6 = 170.89999999999998
 7 = 73.20000000000002
 8 = 12.2
 9 = 50.0
 10 = 221.89999999999998
 11 = 104.1
 12 = 161.29999999999998
 13 = 50.4
 14 = 39.0      
      

7. KMeansUpdateCentroids

基于点计数和坐标,计算新的聚类中心。这里就是从task manager中取出了AllReduce存储的共享变量CENTROID_ALL_REDUCE。

/**
 * Update the centroids based on the sum of points and point number belonging to the same cluster.
 */
public class KMeansUpdateCentroids extends ComputeFunction {
    public void calc(ComContext context) {

        Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
        Integer k = context.getObj(KMeansTrainBatchOp.K);

        // 这里取出AllReduce存储的共享变量
        double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);

        Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
        if (context.getStepNo() % 2 == 0) {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
        } else {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
        }

        stepNumCentroids.f0 = context.getStepNo();

        context.putObj(KMeansTrainBatchOp.K,
            updateCentroids(stepNumCentroids.f1, k, vectorSize, sumMatrixData, distance));
    }
}

0xFF 参考

我的并行计算之路(四)MPI集合通信之Reduce和Allreduce

Message Passing Interface(MPI)

Flink 之 Dataflow、Task、subTask、Operator Chains、Slot 介绍

Flink运行时之TaskManager执行Task

★★★★★★关于生活和技术的思考★★★★★★

微信公众账号:罗西的思考