Flink StreamGraph Generation: A Source Code Walkthrough


Transformation & StreamOperator Analysis

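Common transformations on a DataStream include map, filter, and keyBy. These calls build a tree of StreamTransformation objects (the class is named Transformation in the code shown here), and that tree is later converted into a StreamGraph. For example, the source of DataStream.map is shown below; SingleOutputStreamOperator is a subclass of DataStream: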

public <R> SingleOutputStreamOperator<R> map(
        MapFunction<T, R> mapper, TypeInformation<R> outputType) {
    return transform("Map", outputType, new StreamMap<>(clean(mapper)));
}

public <R> SingleOutputStreamOperator<R> transform(
        String operatorName,
        TypeInformation<R> outTypeInfo,
        OneInputStreamOperator<T, R> operator) {

    return doTransform(operatorName, outTypeInfo, SimpleOperatorFactory.of(operator));
}

protected <R> SingleOutputStreamOperator<R> doTransform(
        String operatorName,
        TypeInformation<R> outTypeInfo,
        StreamOperatorFactory<R> operatorFactory) {

    // read the output type of the input Transform to coax out errors about MissingTypeInfo
    transformation.getOutputType();

    OneInputTransformation<T, R> resultTransform =
            new OneInputTransformation<>(
                    // points to the parent Transformation
                    this.transformation,
                    operatorName,
                    operatorFactory,
                    outTypeInfo,
                    environment.getParallelism());

    @SuppressWarnings({"unchecked", "rawtypes"})
    SingleOutputStreamOperator<R> returnStream =
            new SingleOutputStreamOperator(environment, resultTransform);

    // register the transformation in env, which keeps the tree in a List<Transformation<?>>
    getExecutionEnvironment().addOperator(resultTransform);

    return returnStream;
}

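As the code above shows, the map transformation wraps the user-defined MapFunction in the StreamMap operator, wraps StreamMap in a OneInputTransformation, and finally stores that transformation in env. When env.execute is called, the stored transformations are traversed to build the StreamGraph. The wrapping is shown in the figure below: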

[Figure: a user function wrapped in a StreamOperator, which is in turn wrapped in a Transformation]
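The wrapping chain can also be sketched in plain Java. This is a simplified model, not Flink code: MiniOperator, MiniTransformation, MiniEnv, and MiniStream are hypothetical stand-ins for StreamMap, OneInputTransformation, the environment, and DataStream. It only illustrates that each map() call creates a transformation pointing at its parent and registers it in a flat list.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Simplified model of the wrapping chain; none of these types are Flink classes. */
final class WrappingSketch {

    /** Hypothetical stand-in for StreamMap: an operator holding the user function. */
    static final class MiniOperator<T, R> {
        final Function<T, R> userFunction;

        MiniOperator(Function<T, R> userFunction) {
            this.userFunction = userFunction;
        }
    }

    /** Hypothetical stand-in for OneInputTransformation: keeps a reference to its parent. */
    static final class MiniTransformation {
        final String name;
        final MiniTransformation input; // null for a source
        final MiniOperator<?, ?> operator;

        MiniTransformation(String name, MiniTransformation input, MiniOperator<?, ?> operator) {
            this.name = name;
            this.input = input;
            this.operator = operator;
        }
    }

    /** Hypothetical stand-in for the environment: a flat list of transformations. */
    static final class MiniEnv {
        final List<MiniTransformation> transformations = new ArrayList<>();
    }

    /** Hypothetical stand-in for DataStream. */
    static final class MiniStream<T> {
        final MiniEnv env;
        final MiniTransformation transformation;

        MiniStream(MiniEnv env, MiniTransformation transformation) {
            this.env = env;
            this.transformation = transformation;
        }

        <R> MiniStream<R> map(Function<T, R> mapper) {
            // user function -> operator -> transformation (pointing at the parent)
            MiniTransformation result =
                    new MiniTransformation("Map", this.transformation, new MiniOperator<>(mapper));
            env.transformations.add(result); // register in the environment
            return new MiniStream<>(env, result);
        }
    }

    /** Builds source -> map -> map and lists each registered node with its parent. */
    static List<String> demo() {
        MiniEnv env = new MiniEnv();
        // In this sketch the source is not registered; it is reachable through "input".
        MiniStream<String> words =
                new MiniStream<>(env, new MiniTransformation("Source", null, null));
        MiniStream<Integer> lengths = words.map(String::length);
        lengths.map(n -> n * 2);

        List<String> names = new ArrayList<>();
        for (MiniTransformation t : env.transformations) {
            names.add(t.name + "<-" + t.input.name);
        }
        return names;
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

The list in the environment is flat, but the input references still describe a tree, which is why the generator can later walk upstream from any registered transformation.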

How Flink Builds the StreamGraph

The entry point is StreamExecutionEnvironment.execute(), which calls getStreamGraph() to trigger construction of the StreamGraph.

public JobExecutionResult execute() throws Exception {
    return execute(getStreamGraph());
}

public StreamGraph getStreamGraph(boolean clearTransformations) {
    // entry point for building the StreamGraph
    final StreamGraph streamGraph = getStreamGraphGenerator(transformations).generate();
    if (clearTransformations) {
        transformations.clear();
    }
    return streamGraph;
}

// StreamGraphGenerator
public StreamGraph generate() {
    
    // other code omitted; only the core logic is kept

    // translate every transformation stored in env
    for (Transformation<?> transformation : transformations) {
        transform(transformation);
    }
    
    // other code omitted; only the core logic is kept

    return builtStreamGraph;
}

private Collection<Integer> transform(Transformation<?> transform) {
   
    // some code omitted

    @SuppressWarnings("unchecked")
    final TransformationTranslator<?, Transformation<?>> translator =
            (TransformationTranslator<?, Transformation<?>>)
                    translatorMap.get(transform.getClass());

    Collection<Integer> transformedIds;
    if (translator != null) {
        transformedIds = translate(translator, transform);
    } else {
        transformedIds = legacyTransform(transform);
    }

    // need this check because the iterate transformation adds itself before
    // transforming the feedback edges
    if (!alreadyTransformed.containsKey(transform)) {
        alreadyTransformed.put(transform, transformedIds);
    }

    return transformedIds;
}

private Collection<Integer> translate(
        final TransformationTranslator<?, Transformation<?>> translator,
        final Transformation<?> transform) {
    checkNotNull(translator);
    checkNotNull(transform);

    final List<Collection<Integer>> allInputIds = getParentInputIds(transform.getInputs());

    // the recursive call might have already transformed this
    if (alreadyTransformed.containsKey(transform)) {
        return alreadyTransformed.get(transform);
    }

    final String slotSharingGroup =
            determineSlotSharingGroup(
                    transform.getSlotSharingGroup().isPresent()
                            ? transform.getSlotSharingGroup().get().getName()
                            : null,
                    allInputIds.stream()
                            .flatMap(Collection::stream)
                            .collect(Collectors.toList()));

    final TransformationTranslator.Context context =
            new ContextImpl(this, streamGraph, slotSharingGroup, configuration);

    return shouldExecuteInBatchMode
            ? translator.translateForBatch(transform, context)
            : translator.translateForStreaming(transform, context);
}

Each Transformation type has a corresponding TransformationTranslator implementation. The mapping from Transformation to TransformationTranslator is initialized in a static block of StreamGraphGenerator:

static {
    @SuppressWarnings("rawtypes")
    Map<Class<? extends Transformation>, TransformationTranslator<?, ? extends Transformation>>
            tmp = new HashMap<>();
    tmp.put(OneInputTransformation.class, new OneInputTransformationTranslator<>());
    tmp.put(TwoInputTransformation.class, new TwoInputTransformationTranslator<>());
    tmp.put(MultipleInputTransformation.class, new MultiInputTransformationTranslator<>());
    tmp.put(KeyedMultipleInputTransformation.class, new MultiInputTransformationTranslator<>());
    tmp.put(SourceTransformation.class, new SourceTransformationTranslator<>());
    tmp.put(SinkTransformation.class, new SinkTransformationTranslator<>());
    tmp.put(LegacySinkTransformation.class, new LegacySinkTransformationTranslator<>());
    tmp.put(LegacySourceTransformation.class, new LegacySourceTransformationTranslator<>());
    tmp.put(UnionTransformation.class, new UnionTransformationTranslator<>());
    tmp.put(PartitionTransformation.class, new PartitionTransformationTranslator<>());
    tmp.put(SideOutputTransformation.class, new SideOutputTransformationTranslator<>());
    tmp.put(ReduceTransformation.class, new ReduceTransformationTranslator<>());
    tmp.put(
            TimestampsAndWatermarksTransformation.class,
            new TimestampsAndWatermarksTransformationTranslator<>());
    tmp.put(BroadcastStateTransformation.class, new BroadcastStateTransformationTranslator<>());
    tmp.put(
            KeyedBroadcastStateTransformation.class,
            new KeyedBroadcastStateTransformationTranslator<>());
    translatorMap = Collections.unmodifiableMap(tmp);
}
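The lookup pattern used here (an immutable map keyed by the exact runtime class, with a fallback when no entry exists) can be illustrated with a minimal sketch. The types below are hypothetical stand-ins, and the translators are reduced to functions that return a label; only the dispatch logic mirrors the transform() method shown earlier.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/** Simplified model of class-keyed translator dispatch; not Flink code. */
final class TranslatorLookupSketch {

    static class MiniTransformation {}

    static final class OneInput extends MiniTransformation {}

    static final class Partition extends MiniTransformation {}

    /** A transformation type with no registered translator. */
    static final class Custom extends MiniTransformation {}

    private static final Map<
                    Class<? extends MiniTransformation>, Function<MiniTransformation, String>>
            TRANSLATORS;

    static {
        Map<Class<? extends MiniTransformation>, Function<MiniTransformation, String>> tmp =
                new HashMap<>();
        tmp.put(OneInput.class, t -> "one-input translator");
        tmp.put(Partition.class, t -> "partition translator");
        TRANSLATORS = Collections.unmodifiableMap(tmp);
    }

    static String translate(MiniTransformation t) {
        // The key is the exact runtime class, so subclasses do not match implicitly.
        Function<MiniTransformation, String> translator = TRANSLATORS.get(t.getClass());
        return translator != null ? translator.apply(t) : "legacy path";
    }

    public static void main(String[] args) {
        System.out.println(translate(new OneInput()));
        System.out.println(translate(new Custom()));
    }
}
```

Because the map is keyed by the exact class, every concrete transformation type needs its own entry, which matches the shape of the static block above.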

The following uses the translation logic of OneInputTransformationTranslator as an example:

protected Collection<Integer> translateInternal(
        final Transformation<OUT> transformation,
        final StreamOperatorFactory<OUT> operatorFactory,
        final TypeInformation<IN> inputType,
        @Nullable final KeySelector<IN, ?> stateKeySelector,
        @Nullable final TypeInformation<?> stateKeyType,
        final Context context) {
   
    // other code omitted; only the core logic is kept

    streamGraph.addOperator(
            transformationId,
            slotSharingGroup,
            transformation.getCoLocationGroupKey(),
            operatorFactory,
            inputType,
            transformation.getOutputType(),
            transformation.getName());

    // follow the dependencies to get all parent transformations of the current one
    final List<Transformation<?>> parentTransformations = transformation.getInputs();
    
    // some code omitted

    for (Integer inputId : context.getStreamNodeIds(parentTransformations.get(0))) {
        streamGraph.addEdge(inputId, transformationId, 0);
    }

    return Collections.singleton(transformationId);
}

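The translation can be summarized as follows. First, the upstream transformations of the current transformation are translated recursively, so that everything upstream is already in the graph. Then a StreamNode is built from the transformation. Finally, the node is connected to its upstream nodes, which creates the StreamEdges.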
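The three steps can be sketched as a small recursive function with memoization. This is a simplified model under stated assumptions: MiniTransformation is a hypothetical stand-in, every transformation maps to exactly one node, and edges are recorded as plain strings. The alreadyTransformed map plays the same role as the field of that name in the Flink code above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Simplified model of recursive translation with memoization; not Flink code. */
final class RecursiveTranslationSketch {

    static final class MiniTransformation {
        final int id;
        final List<MiniTransformation> inputs;

        MiniTransformation(int id, MiniTransformation... inputs) {
            this.id = id;
            this.inputs = Arrays.asList(inputs);
        }
    }

    final Map<MiniTransformation, Integer> alreadyTransformed = new HashMap<>();
    final List<Integer> nodes = new ArrayList<>();
    final List<String> edges = new ArrayList<>();

    int transform(MiniTransformation t) {
        Integer known = alreadyTransformed.get(t);
        if (known != null) {
            return known; // already in the graph, nothing to do
        }
        // Step 1: make sure every upstream transformation is translated first.
        List<Integer> inputIds = new ArrayList<>();
        for (MiniTransformation input : t.inputs) {
            inputIds.add(transform(input));
        }
        // Step 2: create the node for this transformation.
        nodes.add(t.id);
        // Step 3: connect it to its upstream nodes.
        for (int inputId : inputIds) {
            edges.add(inputId + "->" + t.id);
        }
        alreadyTransformed.put(t, t.id);
        return t.id;
    }

    static RecursiveTranslationSketch demo() {
        MiniTransformation source = new MiniTransformation(1);
        MiniTransformation map = new MiniTransformation(2, source);
        MiniTransformation filter = new MiniTransformation(3, map);

        RecursiveTranslationSketch graph = new RecursiveTranslationSketch();
        // Only map and filter are passed in; the source is reached through recursion.
        for (MiniTransformation t : Arrays.asList(map, filter)) {
            graph.transform(t);
        }
        return graph;
    }

    public static void main(String[] args) {
        RecursiveTranslationSketch graph = demo();
        System.out.println(graph.nodes + " " + graph.edges);
    }
}
```

The memo check is what keeps a shared upstream transformation from being added twice when several downstream transformations reference it.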

Next, look at how a logical transformation (partition) is handled. Below is the source of PartitionTransformationTranslator:

private Collection<Integer> translateInternal(
        final PartitionTransformation<OUT> transformation, final Context context) {
    
    // some code omitted

    final Transformation<?> input = parentTransformations.get(0);
    List<Integer> resultIds = new ArrayList<>();

    for (Integer inputId : context.getStreamNodeIds(input)) {
        final int virtualId = Transformation.getNewNodeId();
        streamGraph.addVirtualPartitionNode(
                inputId,
                virtualId,
                transformation.getPartitioner(),
                transformation.getExchangeMode());
        resultIds.add(virtualId);
    }
    return resultIds;
}

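As the code shows, translating a partition does not create an actual StreamNode or StreamEdge; it only adds a virtual node. When a transformation downstream of the partition (such as map) adds an edge by calling StreamGraph.addEdge, the partition information is written into that edge.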

private void addEdgeInternal(
        Integer upStreamVertexID,
        Integer downStreamVertexID,
        int typeNumber,
        StreamPartitioner<?> partitioner,
        List<String> outputNames,
        OutputTag outputTag,
        StreamExchangeMode exchangeMode) {

    // upstream is a virtual side-output node
    if (virtualSideOutputNodes.containsKey(upStreamVertexID)) {
        int virtualId = upStreamVertexID;
        upStreamVertexID = virtualSideOutputNodes.get(virtualId).f0;
        if (outputTag == null) {
            outputTag = virtualSideOutputNodes.get(virtualId).f1;
        }
        addEdgeInternal(
                upStreamVertexID,
                downStreamVertexID,
                typeNumber,
                partitioner,
                null,
                outputTag,
                exchangeMode);
      // upstream is a virtual partition node
    } else if (virtualPartitionNodes.containsKey(upStreamVertexID)) {
        int virtualId = upStreamVertexID;
        upStreamVertexID = virtualPartitionNodes.get(virtualId).f0;
        if (partitioner == null) {
            partitioner = virtualPartitionNodes.get(virtualId).f1;
        }
        exchangeMode = virtualPartitionNodes.get(virtualId).f2;
        
        // recurse to the upstream non-virtual node and carry the virtual node's partition info into the edge
        addEdgeInternal(
                upStreamVertexID,
                downStreamVertexID,
                typeNumber,
                partitioner,
                outputNames,
                outputTag,
                exchangeMode);
      // actually build the StreamEdge
    } else {
        StreamNode upstreamNode = getStreamNode(upStreamVertexID);
        StreamNode downstreamNode = getStreamNode(downStreamVertexID);

        // If no partitioner was specified and the parallelism of upstream and downstream
        // operator matches use forward partitioning, use rebalance otherwise.
        if (partitioner == null
                && upstreamNode.getParallelism() == downstreamNode.getParallelism()) {
            partitioner = new ForwardPartitioner<Object>();
        } else if (partitioner == null) {
            partitioner = new RebalancePartitioner<Object>();
        }

        if (partitioner instanceof ForwardPartitioner) {
            if (upstreamNode.getParallelism() != downstreamNode.getParallelism()) {
                throw new UnsupportedOperationException(
                        "Forward partitioning does not allow "
                                + "change of parallelism. Upstream operation: "
                                + upstreamNode
                                + " parallelism: "
                                + upstreamNode.getParallelism()
                                + ", downstream operation: "
                                + downstreamNode
                                + " parallelism: "
                                + downstreamNode.getParallelism()
                                + " You must use another partitioning strategy, such as broadcast, rebalance, shuffle or global.");
            }
        }

        if (exchangeMode == null) {
            exchangeMode = StreamExchangeMode.UNDEFINED;
        }

        StreamEdge edge =
                new StreamEdge(
                        upstreamNode,
                        downstreamNode,
                        typeNumber,
                        partitioner,
                        outputTag,
                        exchangeMode);

        // link the StreamNodes and the StreamEdge together
        getStreamNode(edge.getSourceId()).addOutEdge(edge);
        getStreamNode(edge.getTargetId()).addInEdge(edge);
    }
}
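The virtual-node resolution and the default partitioner choice can be condensed into a short sketch. This is a simplified model, not Flink code: partitioners are plain strings, side outputs and exchange modes are left out, and the node ids in demo() are made up for illustration.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Simplified model of virtual partition nodes and edge creation; not Flink code. */
final class VirtualNodeSketch {

    static final class VirtualPartition {
        final int upstreamId;
        final String partitioner;

        VirtualPartition(int upstreamId, String partitioner) {
            this.upstreamId = upstreamId;
            this.partitioner = partitioner;
        }
    }

    final Map<Integer, Integer> parallelism = new HashMap<>(); // real node id -> parallelism
    final Map<Integer, VirtualPartition> virtualPartitionNodes = new HashMap<>();
    final List<String> edges = new ArrayList<>();

    void addNode(int id, int nodeParallelism) {
        parallelism.put(id, nodeParallelism);
    }

    void addVirtualPartitionNode(int upstreamId, int virtualId, String partitioner) {
        virtualPartitionNodes.put(virtualId, new VirtualPartition(upstreamId, partitioner));
    }

    void addEdge(int upstreamId, int downstreamId) {
        addEdgeInternal(upstreamId, downstreamId, null);
    }

    private void addEdgeInternal(int upstreamId, int downstreamId, String partitioner) {
        VirtualPartition virtual = virtualPartitionNodes.get(upstreamId);
        if (virtual != null) {
            // Upstream is virtual: pick up its partitioner and recurse to the real node.
            if (partitioner == null) {
                partitioner = virtual.partitioner;
            }
            addEdgeInternal(virtual.upstreamId, downstreamId, partitioner);
            return;
        }
        boolean sameParallelism =
                parallelism.get(upstreamId).equals(parallelism.get(downstreamId));
        if (partitioner == null) {
            // Default: forward when parallelism matches, rebalance otherwise.
            partitioner = sameParallelism ? "FORWARD" : "REBALANCE";
        }
        if (partitioner.equals("FORWARD") && !sameParallelism) {
            throw new UnsupportedOperationException(
                    "Forward partitioning does not allow change of parallelism.");
        }
        edges.add(upstreamId + "-[" + partitioner + "]->" + downstreamId);
    }

    static List<String> demo() {
        VirtualNodeSketch graph = new VirtualNodeSketch();
        graph.addNode(1, 2);                          // source, parallelism 2
        graph.addVirtualPartitionNode(1, 100, "HASH"); // keyBy: virtual node only
        graph.addNode(2, 4);                          // map after keyBy
        graph.addEdge(100, 2);                        // resolves to 1 -> 2 with HASH
        graph.addNode(3, 4);                          // same parallelism as map
        graph.addEdge(2, 3);
        graph.addNode(4, 1);                          // different parallelism
        graph.addEdge(3, 4);
        return graph.edges;
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

In the sketch, the virtual id 100 never appears in an edge: the edge starts at the real upstream node 1 and carries the partitioner that the virtual node recorded.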

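Note: since Flink 1.11 there is a new deployment mode (run-application). In this mode the StreamGraph is also generated on the cluster, so it is no longer accurate to say that the StreamGraph is always generated in the client.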