Spring AI Alibaba Graph Source Code Walkthrough Series: Core Startup Classes


Original article: Spring AI Alibaba Graph Source Code Walkthrough Series: Core Startup Classes

[!TIP] Source code taken from Spring AI Alibaba version 1.0.0.3

Core Classes

OverAllState

The core state-management class used during graph execution. It stores and manages the data shared among the nodes of the graph.

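| Field | Type | Description |
| --- | --- | --- |
| data | `Map<String, Object>` | Holds the actual state data |
| keyStrategies | `Map<String, KeyStrategy>` | The update strategy for each key |
| resume | Boolean | Whether the state is being used to resume an interrupted execution |
| humanFeedback | HumanFeedback | Human feedback message |
| interruptMessage | String | Interrupt message |
| DEFAULTINPUTKEY | String | Static string constant "input", the default input key name |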
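
The `keyStrategies` field decides how each key is updated. In the source code later in this section, every update calls `keyStrategies.get(key).apply(oldValue, newValue)`, and the default `input` key is registered with `ReplaceStrategy`. The sketch below illustrates that idea with a simplified two-argument interface. The names `SketchKeyStrategy`, `SketchReplaceStrategy`, and `SketchAppendStrategy` are hypothetical, and the append behaviour is an illustrative assumption, not the library's implementation.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Hypothetical stand-in for the library's KeyStrategy (not its real declaration):
// it receives the value already stored under a key plus the incoming value,
// and returns the value that should be stored next.
interface SketchKeyStrategy {
    Object apply(Object oldValue, Object newValue);
}

// Overwrite semantics, as the name ReplaceStrategy suggests: the new value wins.
class SketchReplaceStrategy implements SketchKeyStrategy {
    @Override
    public Object apply(Object oldValue, Object newValue) {
        return newValue;
    }
}

// Hypothetical append semantics: old and new values are accumulated into one list.
// Collections are flattened one level; a single value is added as one element.
class SketchAppendStrategy implements SketchKeyStrategy {
    @Override
    public Object apply(Object oldValue, Object newValue) {
        List<Object> merged = new ArrayList<>();
        addAll(merged, oldValue);
        addAll(merged, newValue);
        return merged;
    }

    private static void addAll(List<Object> target, Object value) {
        if (value instanceof Collection) {
            target.addAll((Collection<?>) value);
        } else if (value != null) {
            target.add(value);
        }
    }
}
```

With a strategy per key, the same update can overwrite one key (such as `input`) while accumulating another (such as a message history).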

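Public methods

| Category | Method | Description |
| --- | --- | --- |
| Constructors | `OverAllState()` | Default constructor: initializes empty data and strategy maps and registers the default input key |
| | `OverAllState(Map data)` | Initializes the state from existing data |
| | `OverAllState(boolean resume)` | Creates an empty state with the given resume flag |
| | `OverAllState(Map data, Map keyStrategies, Boolean resume)` | Full constructor (`protected`); also registers the default input key |
| Data access | `data()` | Returns an unmodifiable view of the data map |
| | `value(String key)` | Returns the value for the key, wrapped in an `Optional` |
| | `value(String key, Class<T> type)` | Returns the value cast to the given type, wrapped in an `Optional` |
| | `value(String key, T defaultValue)` | Returns the value, or the default value if the key is absent |
| | `keyStrategies()` | Returns the key-strategy map |
| State update | `input(Map input)` | Applies input data to the state: a `null` map switches to resume mode, an empty map is ignored, and only keys with a registered strategy are written |
| | `updateState(Map partialState)` | Updates the state using the registered strategies |
| | `updateStateBySchema(Map state, Map partialState, Map keyStrategies)` | Updates the state using the strategies defined by the schema |
| Resume mode | `isResume()` | Checks whether the state is in resume mode |
| | `withResume()` | Switches to resume mode |
| | `withoutResume()` | Leaves resume mode |
| | `copyWithResume()` | Creates a copy in resume mode |
| Human feedback | `humanFeedback()` | Returns the human feedback |
| | `withHumanFeedback(HumanFeedback)` | Sets the human feedback |
| | `interruptMessage()` | Returns the interrupt message |
| | `setInterruptMessage(String)` | Sets the interrupt message |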
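
The source of `input` has three distinct branches. A `null` map switches the state into resume mode. An empty map is a no-op. Otherwise, only keys that already have a registered strategy are written, and all other keys are silently dropped. The hypothetical `MiniState` class below reproduces only that logic, with strategies modeled as `BinaryOperator<Object>`. It is a simplified sketch, not the real `OverAllState`.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BinaryOperator;

// Hypothetical miniature of OverAllState, reduced to the logic of input(Map).
// A strategy is modeled as (oldValue, newValue) -> value to store.
class MiniState {
    private final Map<String, Object> data = new HashMap<>();
    private final Map<String, BinaryOperator<Object>> keyStrategies = new HashMap<>();
    private boolean resume;

    MiniState registerKeyAndStrategy(String key, BinaryOperator<Object> strategy) {
        keyStrategies.put(key, strategy);
        return this;
    }

    MiniState input(Map<String, Object> input) {
        if (input == null) {      // null input: switch to resume mode, data untouched
            resume = true;
            return this;
        }
        if (input.isEmpty()) {    // empty input: nothing to apply
            return this;
        }
        input.forEach((key, value) -> {
            BinaryOperator<Object> strategy = keyStrategies.get(key);
            if (strategy != null) { // keys without a registered strategy are ignored
                data.put(key, strategy.apply(data.get(key), value));
            }
        });
        return this;
    }

    Map<String, Object> data() {
        return Collections.unmodifiableMap(data); // read-only view, like OverAllState.data()
    }

    boolean isResume() {
        return resume;
    }
}
```

In practice, a key must be registered (for example through `registerKeyAndStrategy` or a `KeyStrategyFactory`) before input for that key has any effect.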

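Static nested class HumanFeedback stores the human feedback received during workflow execution:

  • Map<String, Object> data: the feedback content, stored as arbitrary key-value pairs
  • String nextNodeId: the ID of the next node to execute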
package com.alibaba.cloud.ai.graph;

import org.springframework.util.CollectionUtils;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;

import static com.alibaba.cloud.ai.graph.utils.CollectionsUtils.entryOf;
import static java.util.Collections.unmodifiableMap;
import static java.util.Optional.ofNullable;

public final class OverAllState implements Serializable {

    /**
     * Internal map storing the actual state data. All get/set operations on state values
     * go through this map.
     */
    private final Map<String, Object> data;

    /**
     * Mapping of keys to their respective update strategies. Determines how values for
     * each key should be merged or updated.
     */
    private final Map<String, KeyStrategy> keyStrategies;

    /**
     * Indicates whether this state is being used to resume a previously interrupted
     * execution. If true, certain initialization steps may be skipped.
     */
    private Boolean resume;

    /**
     * Holds optional human feedback information provided during execution. May be null if
     * no feedback was given.
     */
    private HumanFeedback humanFeedback;

    /**
     * Optional message indicating that the execution was interrupted. If non-null,
     * indicates that the graph should halt or handle the interruption.
     */
    private String interruptMessage;

    /**
     * The default key used for standard input injection into the state. Typically used
     * when initializing the state with user or external input.
     */
    public static final String DEFAULTINPUTKEY = "input";

    /**
     * Reset.
     */
    public void reset() {
       this.data.clear();
    }

    /**
     * Snap shot optional.
     * @return the optional
     */
    public Optional<OverAllState> snapShot() {
       return Optional.of(new OverAllState(new HashMap<>(this.data), new HashMap<>(this.keyStrategies), this.resume));
    }

    /**
     * Instantiates a new Over all state.
     * @param resume the is resume
     */
    public OverAllState(boolean resume) {
       this.data = new HashMap<>();
       this.keyStrategies = new HashMap<>();
       this.resume = resume;
    }

    /**
     * Instantiates a new Over all state.
     * @param data the data
     */
    public OverAllState(Map<String, Object> data) {
       this.data = new HashMap<>(data);
       this.keyStrategies = new HashMap<>();
       this.resume = false;
    }

    /**
     * Instantiates a new Over all state.
     */
    public OverAllState() {
       this.data = new HashMap<>();
       this.keyStrategies = new HashMap<>();
       this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
       this.resume = false;
    }

    /**
     * Instantiates a new Over all state.
     * @param data the data
     * @param keyStrategies the key strategies
     * @param resume the resume
     */
    protected OverAllState(Map<String, Object> data, Map<String, KeyStrategy> keyStrategies, Boolean resume) {
       this.data = data;
       this.keyStrategies = keyStrategies;
       this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
       this.resume = resume;
    }

    /**
     * Interrupt message string.
     * @return the string
     */
    public String interruptMessage() {
       return interruptMessage;
    }

    /**
     * Sets interrupt message.
     * @param interruptMessage the interrupt message
     */
    public void setInterruptMessage(String interruptMessage) {
       this.interruptMessage = interruptMessage;
    }

    /**
     * With human feedback.
     * @param humanFeedback the human feedback
     */
    public void withHumanFeedback(HumanFeedback humanFeedback) {
       this.humanFeedback = humanFeedback;
    }

    /**
     * Human feedback human feedback.
     * @return the human feedback
     */
    public HumanFeedback humanFeedback() {
       return this.humanFeedback;
    }

    /**
     * Copy with resume over all state.
     * @return the over all state
     */
    public OverAllState copyWithResume() {
       return new OverAllState(this.data, this.keyStrategies, true);
    }

    /**
     * With resume.
     */
    public void withResume() {
       this.resume = true;
    }

    /**
     * Without resume.
     */
    public void withoutResume() {
       this.resume = false;
    }

    /**
     * Is resume boolean.
     * @return the boolean
     */
    public boolean isResume() {
       return this.resume;
    }

    /**
     * Clears all data in the current state, leaving key strategies, resume flag, and
     * human feedback intact.
     */
    public void clear() {
       this.data.clear();
    }

    /**
     * Replaces the current state's contents with the provided state.
     * <p>
     * This method effectively copies all data, key strategies, resume flag, and human
     * feedback from the provided state to this state.
     * @param overAllState the state to copy from
     */
    public void cover(OverAllState overAllState) {
       this.keyStrategies.clear();
       this.keyStrategies.putAll(overAllState.keyStrategies());
       this.data.clear();
       this.data.putAll(overAllState.data());
       this.resume = overAllState.resume;
       this.humanFeedback = overAllState.humanFeedback;
    }

    /**
     * Inputs over all state.
     * @param input the input
     * @return the over all state
     */
    public OverAllState input(Map<String, Object> input) {
       if (input == null) {
          withResume();
          return this;
       }

       if (CollectionUtils.isEmpty(input)) {
          return this;
       }

       Map<String, KeyStrategy> keyStrategies = keyStrategies();
       input.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
          this.data.put(key, keyStrategies.get(key).apply(value(key, null), input.get(key)));
       });
       return this;
    }

    /**
     * Add key and strategy over all state.
     * @param key the key
     * @param strategy the strategy
     * @return the over all state
     */
    public OverAllState registerKeyAndStrategy(String key, KeyStrategy strategy) {
       this.keyStrategies.put(key, strategy);
       return this;
    }

    /**
     * Register key and strategy over all state.
     * @param keyStrategies the key strategies
     * @return the over all state
     */
    public OverAllState registerKeyAndStrategy(Map<String, KeyStrategy> keyStrategies) {
       this.keyStrategies.putAll(keyStrategies);
       return this;
    }

    /**
     * Is contain strategy boolean.
     * @param key the key
     * @return the boolean
     */
    public boolean containStrategy(String key) {
       return this.keyStrategies.containsKey(key);
    }

    /**
     * Update state map.
     * @param partialState the partial state
     * @return the map
     */
    public Map<String, Object> updateState(Map<String, Object> partialState) {
       Map<String, KeyStrategy> keyStrategies = keyStrategies();
       partialState.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
          this.data.put(key, keyStrategies.get(key).apply(value(key, null), partialState.get(key)));
       });
       return data();
    }

    /**
     * Updates the internal state based on a schema-defined strategy.
     * <p>
     * This method first validates the input state, then updates the partial state
     * according to the provided key strategies. The updated state is formed by merging
     * the original state and the modified partial state, removing any null values in the
     * process. The resulting entries are then used to update the internal data map.
     * @param state the base state to update; must not be null
     * @param partialState the partial state containing updates; may be null or empty
     * @param keyStrategies the mapping of keys to update strategies; used to transform
     * values
     */
    public void updateStateBySchema(Map<String, Object> state, Map<String, Object> partialState,
          Map<String, KeyStrategy> keyStrategies) {
       updateState(updateState(state, partialState, keyStrategies));
    }

    /**
     * Key verify boolean.
     * @return the boolean
     */
    protected boolean keyVerify() {
       return hasCommonKey(this.data, getKeyStrategies());
    }

    private Map<?, ?> getKeyStrategies() {
       return this.keyStrategies;
    }

    private boolean hasCommonKey(Map<?, ?> map1, Map<?, ?> map2) {
       Set<?> keys1 = map1.keySet();
       for (Object key : map2.keySet()) {
          if (keys1.contains(key)) {
             return true;
          }
       }
       return false;
    }

    /**
     * Updates a state with the provided partial state. The merge function is used to
     * merge the current state value with the new value.
     * @param state the current state
     * @param partialState the partial state to update from
     * @return the updated state
     * @throws NullPointerException if state is null
     */
    public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState) {
       Objects.requireNonNull(state, "state cannot be null");
       if (partialState == null || partialState.isEmpty()) {
          return state;
       }

       return Stream.concat(state.entrySet().stream(), partialState.entrySet().stream())
          .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, OverAllState::mergeFunction));
    }

    /**
     * Update state map.
     * @param state the state
     * @param partialState the partial state
     * @param keyStrategies the key strategies
     * @return the map
     */
    public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState,
          Map<String, KeyStrategy> keyStrategies) {
       Objects.requireNonNull(state, "state cannot be null");
       if (partialState == null || partialState.isEmpty()) {
          return state;
       }

       Map<String, Object> updatedPartialState = updatePartialStateFromSchema(state, partialState, keyStrategies);

       return Stream.concat(state.entrySet().stream(), updatedPartialState.entrySet().stream())
          .collect(toMapRemovingNulls(Map.Entry::getKey, Map.Entry::getValue, (currentValue, newValue) -> newValue));
    }

    /**
     * Updates the partial state from a schema using channels.
     * @param state The current state as a map of key-value pairs.
     * @param partialState The partial state to be updated.
     * @param keyStrategies A map of channel names to their implementations.
     * @return An updated version of the partial state after applying the schema and
     * channels.
     */
    private static Map<String, Object> updatePartialStateFromSchema(Map<String, Object> state,
          Map<String, Object> partialState, Map<String, KeyStrategy> keyStrategies) {
       if (keyStrategies == null || keyStrategies.isEmpty()) {
          return partialState;
       }
       return partialState.entrySet().stream().map(entry -> {

          KeyStrategy channel = keyStrategies.get(entry.getKey());
          if (channel != null) {
             Object newValue = channel.apply(state.get(entry.getKey()), entry.getValue());
             return entryOf(entry.getKey(), newValue);
          }

          return entry;
       }).collect(toMapAllowingNulls(Map.Entry::getKey, Map.Entry::getValue));
    }

    private static <T, K, U> Collector<T, ?, Map<K, U>> toMapRemovingNulls(Function<? super T, ? extends K> keyMapper,
          Function<? super T, ? extends U> valueMapper, BinaryOperator<U> mergeFunction) {
       return Collector.of(HashMap::new, (map, element) -> {
          K key = keyMapper.apply(element);
          U value = valueMapper.apply(element);
          if (value == null) {
             map.remove(key);
          }
          else {
             map.merge(key, value, mergeFunction);
          }
       }, (map1, map2) -> {
          map2.forEach((key, value) -> {
             if (value != null) {
                map1.merge(key, value, mergeFunction);
             }
          });
          return map1;
       }, Collector.Characteristics.UNORDERED);
    }

    private static <T, K, U> Collector<T, ?, Map<K, U>> toMapAllowingNulls(Function<? super T, ? extends K> keyMapper,
          Function<? super T, ? extends U> valueMapper) {
       return Collector.of(HashMap::new,
             (map, element) -> map.put(keyMapper.apply(element), valueMapper.apply(element)), (map1, map2) -> {
                map1.putAll(map2);
                return map1;
             }, Collector.Characteristics.UNORDERED);
    }

    /**
     * Merges the current value with the new value using the appropriate merge function.
     * @param currentValue the current value
     * @param newValue the new value
     * @return the merged value
     */
    private static Object mergeFunction(Object currentValue, Object newValue) {
       return newValue;
    }

    /**
     * Key strategies map.
     * @return the map
     */
    public Map<String, KeyStrategy> keyStrategies() {
       return keyStrategies;
    }

    /**
     * Data map.
     * @return the map
     */
    public final Map<String, Object> data() {
       return unmodifiableMap(data);
    }

    /**
     * Value optional.
     * @param <T> the type parameter
     * @param key the key
     * @return the optional
     */
    public final <T> Optional<T> value(String key) {
       return ofNullable((T) data().get(key));
    }

    /**
     * Value optional.
     * @param <T> the type parameter
     * @param key the key
     * @param type the type
     * @return the optional
     */
    public final <T> Optional<T> value(String key, Class<T> type) {
       if (type != null) {
          return ofNullable(type.cast(data().get(key)));
       }
       return value(key);
    }

    /**
     * Value t.
     * @param <T> the type parameter
     * @param key the key
     * @param defaultValue the default value
     * @return the t
     */
    public final <T> T value(String key, T defaultValue) {
       return (T) value(key).orElse(defaultValue);
    }

    /**
     * The type Human feedback.
     */
    public static class HumanFeedback implements Serializable {

       private Map<String, Object> data;

       private String nextNodeId;

       private String currentNodeId;

       /**
        * Instantiates a new Human feedback.
        * @param data the data
        * @param nextNodeId the next node id
        */
       public HumanFeedback(Map<String, Object> data, String nextNodeId) {
          this.data = data;
          this.nextNodeId = nextNodeId;
       }

       /**
        * Data map.
        * @return the map
        */
       public Map<String, Object> data() {
          return data;
       }

       /**
        * Next node id string.
        * @return the string
        */
       public String nextNodeId() {
          return nextNodeId;
       }

       /**
        * Sets data.
        * @param data the data
        */
       public void setData(Map<String, Object> data) {
          this.data = data;
       }

       /**
        * Sets next node id.
        * @param nextNodeId the next node id
        */
       public void setNextNodeId(String nextNodeId) {
          this.nextNodeId = nextNodeId;
       }

    }

    @Override
    public String toString() {
       return "OverAllState{" + "data=" + data + ", resume=" + resume + ", humanFeedback=" + humanFeedback
             + ", interruptMessage='" + interruptMessage + '\'' + '}';
    }

}
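
One detail in the code above is easy to miss. `snapShot()` wraps `data` and `keyStrategies` in new `HashMap`s, whereas `copyWithResume()` passes `this.data` and `this.keyStrategies` straight to the protected constructor, which stores the references as-is and then registers the default `input` key into that shared strategy map. The "copy" therefore shares its maps with the original. The hypothetical `CopyModel` below isolates only the data-map difference so the effect can be checked on its own. It is a sketch, not library code.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of how OverAllState's two copy methods treat the data map.
class CopyModel {
    final Map<String, Object> data;
    final boolean resume;

    CopyModel(Map<String, Object> data, boolean resume) {
        this.data = data;      // like the protected constructor: keeps the reference as-is
        this.resume = resume;
    }

    // Like copyWithResume(): passes this.data itself, so both objects share one map.
    CopyModel copyWithResume() {
        return new CopyModel(this.data, true);
    }

    // Like snapShot(): copies the entries into a new HashMap, so the result is independent.
    CopyModel snapShot() {
        return new CopyModel(new HashMap<>(this.data), this.resume);
    }
}
```

A write through the resume copy is therefore visible in the original state, while a snapshot is unaffected.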

RunnableConfig

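Run configuration class. It holds per-run parameters such as the thread ID, checkpoint ID, next node, and stream mode.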

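| Field | Type | Description |
| --- | --- | --- |
| threadId | String | Thread ID |
| checkPointId | String | Checkpoint ID |
| nextNode | String | ID of the next node to execute |
| streamMode | CompiledGraph.StreamMode | Stream mode of the compiled graph (see the CompiledGraph class for details) |
| metadata | `Map<String, Object>` | Custom metadata |
| interruptedNodes | `Map<String, Object>` | Information about interrupted nodes |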
package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;

import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;

import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.Optional.ofNullable;

/**
 * A final class representing configuration for a runnable task. This class holds various
 * parameters such as thread ID, checkpoint ID, next node, and stream mode, providing
 * methods to modify these parameters safely without permanently altering the original
 * configuration.
 */
public final class RunnableConfig implements HasMetadata<RunnableConfig.Builder> {

    private final String threadId;

    private final String checkPointId;

    private final String nextNode;

    private final CompiledGraph.StreamMode streamMode;

    private final Map<String, Object> metadata;

    private final Map<String, Object> interruptedNodes;

    /**
     * Returns the stream mode of the compiled graph.
     * @return {@code StreamMode} representing the current stream mode.
     */
    public CompiledGraph.StreamMode streamMode() {
       return streamMode;
    }

    /**
     * Returns the thread ID as an {@link Optional}.
     * @return the thread ID wrapped in an {@code Optional}, or an empty {@code Optional}
     * if no thread ID is set.
     */
    public Optional<String> threadId() {
       return ofNullable(threadId);
    }

    /**
     * Returns the current {@code checkPointId} wrapped in an {@link Optional}.
     * @return an {@link Optional} containing the {@code checkPointId}, or
     * {@link Optional#empty()} if it is null.
     */
    public Optional<String> checkPointId() {
       return ofNullable(checkPointId);
    }

    /**
     * Returns an {@code Optional} describing the next node in the sequence, or an empty
     * {@code Optional} if there is no such node.
     * @return an {@code Optional} describing the next node, or an empty {@code Optional}
     */
    public Optional<String> nextNode() {
       return ofNullable(nextNode);
    }

    /**
     * Checks if a node is marked as interrupted in the metadata.
     * @param nodeId the ID of the node to check for interruption status
     * @return true if the node is marked as interrupted, false otherwise
     */
    public boolean isInterrupted(String nodeId) {
       return interruptData(HasMetadata.formatNodeId(nodeId)).map(value -> Boolean.TRUE.equals(value)).orElse(false);
    }

    /**
     * Marks a node as not interrupted by setting its value to false in the metadata.
     * @param nodeId the ID of the node to mark as not interrupted
     * @return a new {@code RunnableConfig} instance with the updated metadata
     */
    public void withNodeResumed(String nodeId) {
       String formattedNodeId = HasMetadata.formatNodeId(nodeId);
       interruptedNodes.put(formattedNodeId, false);
    }

    /**
     * Removes the interrupted marker for a specific node by removing its entry from the
     * metadata.
     * @param nodeId the ID of the node to remove the interrupted marker for
     * @return a new {@code RunnableConfig} instance with the updated metadata
     */
    public void removeInterrupted(String nodeId) {
       String formattedNodeId = HasMetadata.formatNodeId(nodeId);
       if (interruptedNodes == null || !interruptedNodes.containsKey(formattedNodeId)) {
          return; // No change needed if the marker doesn't exist
       }
       interruptedNodes.remove(formattedNodeId);
    }

    /**
     * Marks a node as interrupted by adding it to the metadata with a formatted key. The
     * node ID is formatted using {@link #formatNodeId(String)} and associated with a
     * value of {@code true} in the metadata map.
     * @param nodeId the ID of the node to mark as interrupted; must not be null
     * @return this {@code Builder} instance for method chaining
     * @throws NullPointerException if nodeId is null
     */
    public void markNodeAsInterrupted(String nodeId) {
       interruptedNodes.put(HasMetadata.formatNodeId(nodeId), true);
    }

    /**
     * Create a new RunnableConfig with the same attributes as this one but with a
     * different {@link CompiledGraph.StreamMode}.
     * @param streamMode the new stream mode
     * @return a new RunnableConfig with the updated stream mode
     */
    public RunnableConfig withStreamMode(CompiledGraph.StreamMode streamMode) {
       if (this.streamMode == streamMode) {
          return this;
       }

       return RunnableConfig.builder(this).streamMode(streamMode).build();
    }

    /**
     * Updates the checkpoint ID of the configuration.
     * @param checkPointId The new checkpoint ID to set.
     * @return A new instance of {@code RunnableConfig} with the updated checkpoint ID, or
     * the current instance if no change is needed.
     */
    public RunnableConfig withCheckPointId(String checkPointId) {
       if (Objects.equals(this.checkPointId, checkPointId)) {
          return this;
       }
       return RunnableConfig.builder(this).checkPointId(checkPointId).build();

    }

    /**
     * Retrieves interrupt data associated with the specified key.
     * @param key the key for which to retrieve interrupt data; may be null
     * @return an Optional containing the interrupt data if present, or an empty Optional
     * if the key is null or no data is found
     */
    public Optional<Object> interruptData(String key) {
       if (key == null) {
          return Optional.empty();
       }
       return ofNullable(interruptedNodes).map(m -> m.get(key));
    }

    /**
     * return metadata value for key
     * @param key given metadata key
     * @return metadata value for key if any
     */
    @Override
    public Optional<Object> metadata(String key) {
       if (key == null) {
          return Optional.empty();
       }
       return ofNullable(metadata).map(m -> m.get(key));
    }

    /**
     * Creates a new instance of the {@link Builder} class.
     * @return A new {@code Builder} object.
     */
    public static Builder builder() {
       return new Builder();
    }

    /**
     * Creates a new {@code Builder} instance with the specified {@link RunnableConfig}.
     * @param config The configuration for the {@code Builder}.
     * @return A new {@code Builder} instance.
     */
    public static Builder builder(RunnableConfig config) {
       return new Builder(config);
    }

    /**
     * A builder pattern class for constructing {@link RunnableConfig} objects. This class
     * provides a fluent interface to set various properties of a {@link RunnableConfig}
     * object and then build the final configuration.
     */
    public static class Builder extends HasMetadata.Builder<Builder> {

       private String threadId;

       private String checkPointId;

       private String nextNode;

       private CompiledGraph.StreamMode streamMode = CompiledGraph.StreamMode.VALUES;

       /**
        * Constructs a new instance of the {@link Builder} with default configuration
        * settings. Initializes a new {@link RunnableConfig} object for configuration
        * purposes.
        */
       Builder() {
       }

       /**
        * Initializes a new instance of the {@code Builder} class with the specified
        * {@link RunnableConfig}.
        * @param config The configuration to be used for initialization.
        */
       Builder(RunnableConfig config) {
          super(requireNonNull(config, "config cannot be null!").metadata);
          this.threadId = config.threadId;
          this.checkPointId = config.checkPointId;
          this.nextNode = config.nextNode;
          this.streamMode = config.streamMode;
       }

       /**
        * Sets the ID of the thread.
        * @param threadId the ID of the thread to set
        * @return a reference to this {@code Builder} object so that method calls can be
        * chained together
        */
       public Builder threadId(String threadId) {
          this.threadId = threadId;
          return this;
       }

       /**
        * Sets the checkpoint ID for the configuration.
        * @param {@code checkPointId} - the ID of the checkpoint to be set
        * @return {@literal this} - a reference to the current `Builder` instance
        */
       public Builder checkPointId(String checkPointId) {
          this.checkPointId = checkPointId;
          return this;
       }

       /**
        * Sets the next node in the configuration and returns this builder for method
        * chaining.
        * @param nextNode The next node to be set.
        * @return This builder instance, allowing for method chaining.
        */
       public Builder nextNode(String nextNode) {
          this.nextNode = nextNode;
          return this;
       }

       /**
        * Sets the stream mode of the configuration.
        * @param streamMode The {@link CompiledGraph.StreamMode} to set.
        * @return A reference to this builder for method chaining.
        */
       public Builder streamMode(CompiledGraph.StreamMode streamMode) {
          this.streamMode = streamMode;
          return this;
       }

       /**
        * Adds a custom {@link Executor} for a specific parallel node.
        * <p>
        * This allows you to control the execution of branches within a parallel node.
        * When a parallel node is executed, it will look for an executor in the
        * {@link RunnableConfig} metadata. If found, it will be used to run the parallel
        * branches concurrently.
        * @param nodeId the ID of the parallel node.
        * @param executor the {@link Executor} to use for the parallel node.
        * @return this {@code Builder} instance for method chaining.
        */
       public Builder addParallelNodeExecutor(String nodeId, Executor executor) {
          return addMetadata(ParallelNode.formatNodeId(nodeId), requireNonNull(executor, "executor cannot be null!"));
       }

       /**
        * Constructs and returns the configured {@code RunnableConfig} object.
        * @return the configured {@code RunnableConfig} object
        */
       public RunnableConfig build() {
          return new RunnableConfig(this);
       }

    }

    /**
     * Creates a new instance of {@code RunnableConfig} as a copy of the provided
     * {@code config}.
     * @param builder The configuration builder.
     */
    private RunnableConfig(Builder builder) {
       this.threadId = builder.threadId;
       this.checkPointId = builder.checkPointId;
       this.nextNode = builder.nextNode;
       this.streamMode = builder.streamMode;
       this.metadata = ofNullable(builder.metadata()).map(Map::copyOf).orElse(null);
       this.interruptedNodes = new ConcurrentHashMap<>();
    }

    @Override
    public String toString() {
       return format("RunnableConfig{ threadId=%s, checkPointId=%s, nextNode=%s, streamMode=%s }", threadId,
             checkPointId, nextNode, streamMode);
    }

}
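
Two behaviours of `RunnableConfig` follow directly from the code above. First, `withStreamMode` and `withCheckPointId` return the same instance when the value is unchanged, and otherwise build a new one through `builder(this)`. Second, the private constructor always creates a fresh `interruptedNodes` map, and the `Builder` never copies it. A config derived through a `with...` method that creates a new instance therefore starts with no interrupt markers. The hypothetical `MiniConfig` below models just these two points. It omits the node-ID formatting done by `HasMetadata.formatNodeId`.

```java
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical miniature of RunnableConfig: an immutable scalar field, a mutable
// interrupt-marker map that every new instance creates fresh, and a "with" method.
final class MiniConfig {
    private final String checkPointId;
    private final Map<String, Object> interruptedNodes = new ConcurrentHashMap<>();

    MiniConfig(String checkPointId) {
        this.checkPointId = checkPointId;
    }

    String checkPointId() {
        return checkPointId;
    }

    void markNodeAsInterrupted(String nodeId) {
        interruptedNodes.put(nodeId, true); // the real class formats nodeId first
    }

    boolean isInterrupted(String nodeId) {
        return Boolean.TRUE.equals(interruptedNodes.get(nodeId));
    }

    MiniConfig withCheckPointId(String newCheckPointId) {
        if (Objects.equals(checkPointId, newCheckPointId)) {
            return this;                        // unchanged: same instance, markers kept
        }
        return new MiniConfig(newCheckPointId); // changed: new instance, markers not copied
    }
}
```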

StateGraph

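Represents and builds state-based, graph-structured workflows. It provides the following capabilities:

  • Graph definition: an API for building a directed graph, including its nodes and edges
  • State management: works with OverAllState to manage state during graph execution
  • Workflow orchestration: supports complex execution logic, including conditional edges and sub-graphs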
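| Field | Type | Description |
| --- | --- | --- |
| START | String | Constant identifying the graph's start node ("START") |
| END | String | Constant identifying the graph's end node ("END") |
| ERROR | String | Constant identifying the graph's error node ("ERROR") |
| NODEBEFORE | String | Constant identifying the before-node-execution hook ("NODEBEFORE") |
| NODEAFTER | String | Constant identifying the after-node-execution hook ("NODEAFTER") |
| nodes | Nodes | Container for all nodes in the graph |
| edges | Edges | Container for all edges in the graph |
| overAllStateFactory | OverAllStateFactory | Factory for creating the overall state instance (deprecated) |
| keyStrategyFactory | KeyStrategyFactory | Factory that provides the key strategies |
| name | String | Name of the graph |
| stateSerializer | PlainTextStateSerializer | State serializer; the class also defines a Jackson-based nested serializer, `JacksonSerializer` |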
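
A toy model can make the roles of `START`, `END`, `nodes`, and `edges` concrete. It shows how a linear state graph runs. Execution starts from the edge leaving `START`. Each node returns a partial state update, which is merged into the state, and the outgoing edge is followed until `END`. `ToyGraph` and its methods are hypothetical and synchronous. The real `StateGraph` is compiled into a `CompiledGraph`, uses asynchronous node actions, and merges updates through key strategies. None of that is reproduced here; the toy uses a plain replace-style merge.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical toy model of a linear state graph (not the Spring AI Alibaba API).
class ToyGraph {
    static final String START = "START";
    static final String END = "END";

    // A node reads the current state and returns a partial update.
    private final Map<String, Function<Map<String, Object>, Map<String, Object>>> nodes = new HashMap<>();
    // In this linear model, each source node has exactly one successor.
    private final Map<String, String> edges = new HashMap<>();

    ToyGraph addNode(String id, Function<Map<String, Object>, Map<String, Object>> action) {
        if (START.equals(id) || END.equals(id) || nodes.containsKey(id)) {
            throw new IllegalArgumentException("reserved or duplicate node id: " + id);
        }
        nodes.put(id, action);
        return this;
    }

    ToyGraph addEdge(String sourceId, String targetId) {
        edges.put(sourceId, targetId);
        return this;
    }

    Map<String, Object> invoke(Map<String, Object> input) {
        Map<String, Object> state = new HashMap<>(input);
        String current = edges.get(START);
        while (current != null && !END.equals(current)) {
            Function<Map<String, Object>, Map<String, Object>> action = nodes.get(current);
            if (action == null) {
                throw new IllegalStateException("edge points to unknown node: " + current);
            }
            state.putAll(action.apply(state)); // replace-style merge of the partial update
            current = edges.get(current);
        }
        return state;
    }
}
```

For example, a graph `START -> greet -> shout -> END` whose nodes write a `greeting` key turns the input `graph` into `HELLO, GRAPH`, and the original `input` key is kept.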

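Public methods

| Category | Method | Description |
| --- | --- | --- |
| Constructors | `StateGraph()` | No-argument constructor |
| | `StateGraph(KeyStrategyFactory keyStrategyFactory)` | With a key-strategy factory |
| | `StateGraph(String name, KeyStrategyFactory keyStrategyFactory)` | With a name and a key-strategy factory |
| | `StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer)` | With a key-strategy factory and a serializer |
| | `StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer)` | With a name, a key-strategy factory, and a serializer |
| Node management | `addNode(String id, Node node)` | Adds a node instance |
| | `addNode(String id, AsyncNodeAction action)` | Adds a regular node |
| | `addNode(String id, AsyncNodeActionWithConfig actionWithConfig)` | Adds a node whose action also receives the configuration |
| | `addNode(String id, AsyncCommandAction action, Map mappings)` | Adds a node that routes to its successors conditionally through the mappings |
| | `addNode(String id, StateGraph subGraph)` | Adds a sub-graph node |
| | `addNode(String id, CompiledGraph subGraph)` | Adds a compiled sub-graph node |
| Edge management | `addEdge(String sourceId, String targetId)` | Adds a regular edge |
| | `addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map mappings)` | Adds conditional edges driven by an `AsyncEdgeAction` |
| | `addConditionalEdges(String sourceId, AsyncCommandAction condition, Map mappings)` | Adds conditional edges driven by an `AsyncCommandAction` |
| Compilation and validation | `validateGraph()` | Validates the graph structure |
| | `compile()` | Compiles the graph with the default configuration |
| | `compile(CompileConfig config)` | Compiles the graph with the given compile configuration |
| Visualization | `getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges)` | Generates a graph representation |
| | `getGraph(GraphRepresentation.Type type, String title)` | Generates a graph representation that includes conditional edges |
| | `getGraph(GraphRepresentation.Type type)` | Generates a graph representation, using the graph name as the title |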
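
`addConditionalEdges` takes a condition plus a `mappings` map. The sketch below follows one plausible reading of that signature: the condition returns a label, and `mappings` translates the label into the ID of the next node. `ConditionalRouter` is hypothetical and synchronous, whereas the library's `AsyncEdgeAction` is asynchronous. It also fails fast on an unmapped label; how the library handles that case is not shown in this section.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch of conditional-edge routing: condition(state) -> label,
// then mappings(label) -> next node id.
class ConditionalRouter {
    private final Map<String, Function<Map<String, Object>, String>> conditions = new HashMap<>();
    private final Map<String, Map<String, String>> mappings = new HashMap<>();

    ConditionalRouter addConditionalEdges(String sourceId, Function<Map<String, Object>, String> condition,
            Map<String, String> labelToTarget) {
        conditions.put(sourceId, condition);
        mappings.put(sourceId, Map.copyOf(labelToTarget));
        return this;
    }

    String next(String sourceId, Map<String, Object> state) {
        Function<Map<String, Object>, String> condition = conditions.get(sourceId);
        if (condition == null) {
            throw new IllegalStateException("no conditional edges from " + sourceId);
        }
        String label = condition.apply(state);
        String target = label == null ? null : mappings.get(sourceId).get(label);
        if (target == null) {
            throw new IllegalStateException("label '" + label + "' is not mapped for " + sourceId);
        }
        return target;
    }
}
```

The label indirection lets the condition describe the outcome ("ok", "retry"), while the graph definition alone decides which node each outcome leads to.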

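Static nested class Nodes manages the graph's node collection:

  • anyMatchById(String id): checks whether a node with the given ID exists
  • onlySubStateGraphNodes(): returns all sub-graph nodes
  • exceptSubStateGraphNodes(): returns all nodes except the sub-graph nodes

Static nested class Edges manages the graph's edge collection:

  • edgeBySourceId(String sourceId): finds the edge by source node ID
  • edgesByTargetId(String targetId): finds the edges by target node ID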
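
The lookups on `Edges` amount to filters over the edge list. The sketch below assumes, as the singular name `edgeBySourceId` suggests, that each source has at most one edge entry (returned as an `Optional`), while `edgesByTargetId` can return several. `SimpleEdge` and `SimpleEdges` are hypothetical. The library's `Edge` class, which per the imports works with `EdgeValue` and `EdgeCondition`, is richer than this single-target record.

```java
import java.util.List;
import java.util.Optional;

// Hypothetical single-target edge; the library's Edge carries more information.
record SimpleEdge(String sourceId, String targetId) {
}

// Hypothetical counterpart of StateGraph.Edges: an edge list plus two lookups.
class SimpleEdges {
    private final List<SimpleEdge> elements;

    SimpleEdges(List<SimpleEdge> elements) {
        this.elements = List.copyOf(elements);
    }

    // Assumes at most one edge entry per source node, hence Optional.
    Optional<SimpleEdge> edgeBySourceId(String sourceId) {
        return elements.stream().filter(e -> e.sourceId().equals(sourceId)).findFirst();
    }

    // Several edges may point at the same target node.
    List<SimpleEdge> edgesByTargetId(String targetId) {
        return elements.stream().filter(e -> e.targetId().equals(targetId)).toList();
    }
}
```

Lookups like these are what graph validation needs, for example to confirm that every edge target refers to an existing node.
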
package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;

import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeCondition;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.CommandNode;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import com.alibaba.cloud.ai.graph.internal.node.SubCompiledGraphNode;
import com.alibaba.cloud.ai.graph.internal.node.SubStateGraphNode;
import com.alibaba.cloud.ai.graph.serializer.StateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.PlainTextStateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.jackson.JacksonStateSerializer;
import com.alibaba.cloud.ai.graph.state.AgentStateFactory;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.LinkedHashSet;

/**
 * Represents a state graph with nodes and edges.
 */
public class StateGraph {

    /**
     * Constant representing the END of the graph.
     */
    public static final String END = "END";

    /**
     * Constant representing the START of the graph.
     */
    public static final String START = "START";

    /**
     * Constant representing the ERROR of the graph.
     */
    public static final String ERROR = "ERROR";

    /**
     * Constant representing the NODEBEFORE of the graph.
     */
    public static final String NODEBEFORE = "NODEBEFORE";

    /**
     * Constant representing the NODEAFTER of the graph.
     */
    public static final String NODEAFTER = "NODEAFTER";

    /**
     * Collection of nodes in the graph.
     */
    final Nodes nodes = new Nodes();

    /**
     * Collection of edges in the graph.
     */
    final Edges edges = new Edges();

    /**
     * Factory for creating overall state instances.
     */
    private OverAllStateFactory overAllStateFactory;

    /**
     * Factory for providing key strategies.
     */
    private KeyStrategyFactory keyStrategyFactory;

    /**
     * Name of the graph.
     */
    private String name;

    /**
     * Serializer for the state.
     */
    private final PlainTextStateSerializer stateSerializer;

    /**
     * Jackson-based serializer for state.
     */
    static class JacksonSerializer extends JacksonStateSerializer {

       /**
        * Instantiates a new Jackson serializer.
        */
       public JacksonSerializer() {
          super(OverAllState::new);
       }

       /**
        * Gets object mapper.
        * @return the object mapper
        */
       ObjectMapper getObjectMapper() {
          return objectMapper;
       }

    }

    /**
     * Constructs a StateGraph with the specified name, key strategy factory, and state
     * serializer.
     * @param name the name of the graph
     * @param keyStrategyFactory the factory for providing key strategies
     * @param stateSerializer the plain text state serializer to use
     */
    public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
       this.name = name;
       this.keyStrategyFactory = keyStrategyFactory;
       this.stateSerializer = stateSerializer;
    }

    public StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
       this.keyStrategyFactory = keyStrategyFactory;
       this.stateSerializer = stateSerializer;
    }

    /**
     * Constructs a StateGraph with the given key strategy factory and name.
     * @param keyStrategyFactory the factory for providing key strategies
     * @param name the name of the graph
     */
    public StateGraph(String name, KeyStrategyFactory keyStrategyFactory) {
       this.keyStrategyFactory = keyStrategyFactory;
       this.name = name;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Constructs a StateGraph with the provided key strategy factory.
     * @param keyStrategyFactory the factory for providing key strategies
     */
    public StateGraph(KeyStrategyFactory keyStrategyFactory) {
       this.keyStrategyFactory = keyStrategyFactory;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the specified name,
     * overall state factory, and state serializer.
     * @param name the name of the graph
     * @param overAllStateFactory the factory for creating overall state instances
     * @param plainTextStateSerializer the plain text state serializer to use
     */
    @Deprecated
    public StateGraph(String name, OverAllStateFactory overAllStateFactory,
          PlainTextStateSerializer plainTextStateSerializer) {
       this.name = name;
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = plainTextStateSerializer;
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the specified name and
     * overall state factory.
     * @param name the name of the graph
     * @param overAllStateFactory the factory for creating overall state instances
     */
    @Deprecated
    public StateGraph(String name, OverAllStateFactory overAllStateFactory) {
       this.name = name;
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the provided overall
     * state factory.
     * @param overAllStateFactory the factory for creating overall state instances
     */
    @Deprecated
    public StateGraph(OverAllStateFactory overAllStateFactory) {
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the provided overall
     * state factory and state serializer.
     * @param overAllStateFactory the factory for creating overall state instances
     * @param plainTextStateSerializer the plain text state serializer to use
     */
    @Deprecated
    public StateGraph(OverAllStateFactory overAllStateFactory, PlainTextStateSerializer plainTextStateSerializer) {
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = plainTextStateSerializer;
    }

    /**
     * Default constructor that initializes a StateGraph with a Jackson-based state
     * serializer and a key strategy factory that returns an empty map.
     */
    public StateGraph() {
       this.stateSerializer = new JacksonSerializer();
       this.keyStrategyFactory = HashMap::new;
    }

    /**
     * Gets the name of the graph.
     * @return the name
     */
    public String getName() {
       return name;
    }

    /**
     * Gets the state serializer used by this graph.
     * @return the state serializer
     */
    public StateSerializer<OverAllState> getStateSerializer() {
       return stateSerializer;
    }

    /**
     * Gets the state factory associated with this graph's state serializer.
     * @return the state factory
     */
    public final AgentStateFactory<OverAllState> getStateFactory() {
       return stateSerializer.stateFactory();
    }

    /**
     * Gets the overall state factory.
     * @return the overall state factory
     */
    @Deprecated
    public final OverAllStateFactory getOverAllStateFactory() {
       return overAllStateFactory;
    }

    /**
     * Gets the key strategy factory.
     * @return the key strategy factory
     */
    public final KeyStrategyFactory getKeyStrategyFactory() {
       return keyStrategyFactory;
    }

    /**
     * Adds a commandNode to the graph.
     * @param id the identifier of the node
     * @param action AsyncCommandAction action
     * @param mappings the mappings to be used for conditional edges
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncCommandAction action, Map<String, String> mappings)
          throws GraphStateException {

       return addNode(id, new CommandNode(id, action, mappings));
    }

    /**
     * Adds a node to the graph.
     * @param id the identifier of the node
     * @param action the asynchronous node action to be performed by the node
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
       return addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * Adds a node to the graph with the specified action and configuration.
     * @param id the identifier of the node
     * @param actionWithConfig the action to be performed by the node
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
       Node node = new Node(id, (config) -> actionWithConfig);
       return addNode(id, node);
    }

    /**
     * Adds a node to the graph with the specified identifier and node instance.
     * @param id the identifier of the node
     * @param node the node to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, Node node) throws GraphStateException {
       if (Objects.equals(node.id(), END)) {
          throw Errors.invalidNodeIdentifier.exception(END);
       }
       if (!Objects.equals(node.id(), id)) {
          throw Errors.invalidNodeIdentifier.exception(node.id(), id);
       }

       if (nodes.elements.contains(node)) {
          throw Errors.duplicateNodeError.exception(id);
       }

       nodes.elements.add(node);
       return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that the subgraph shares the same state with the parent
     * graph.
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the compiled subgraph to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
       if (Objects.equals(id, END)) {
          throw Errors.invalidNodeIdentifier.exception(END);
       }

       var node = new SubCompiledGraphNode(id, subGraph);

       if (nodes.elements.contains(node)) {
          throw Errors.duplicateNodeError.exception(id);
       }

       nodes.elements.add(node);
       return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that the subgraph will share the same state with the
     * parent graph and will be compiled when the parent is compiled.
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the subgraph to be added; it will be compiled during compilation of
     * the parent
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
       if (Objects.equals(id, END)) {
          throw Errors.invalidNodeIdentifier.exception(END);
       }

       subGraph.validateGraph();

       var node = new SubStateGraphNode(id, subGraph);

       if (nodes.elements.contains(node)) {
          throw Errors.duplicateNodeError.exception(id);
       }

       nodes.elements.add(node);
       return this;
    }

    /**
     * Adds an edge to the graph between the specified source and target nodes.
     * @param sourceId the identifier of the source node
     * @param targetId the identifier of the target node
     * @return this state graph instance
     * @throws GraphStateException if the edge identifier is invalid or the edge already
     * exists
     */
    public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
       if (Objects.equals(sourceId, END)) {
          throw Errors.invalidEdgeIdentifier.exception(END);
       }

       var newEdge = new Edge(sourceId, new EdgeValue(targetId));

       int index = edges.elements.indexOf(newEdge);
       if (index >= 0) {
          var newTargets = new ArrayList<>(edges.elements.get(index).targets());
          newTargets.add(newEdge.target());
          edges.elements.set(index, new Edge(sourceId, newTargets));
       }
       else {
          edges.elements.add(newEdge);
       }

       return this;
    }

    /**
     * Adds conditional edges to the graph based on the provided condition and mappings.
     * @param sourceId the identifier of the source node
     * @param condition the command action used to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @return this state graph instance
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncCommandAction condition, Map<String, String> mappings)
          throws GraphStateException {
       if (Objects.equals(sourceId, END)) {
          throw Errors.invalidEdgeIdentifier.exception(END);
       }
       if (mappings == null || mappings.isEmpty()) {
          throw Errors.edgeMappingIsEmpty.exception(sourceId);
       }

       var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));

       if (edges.elements.contains(newEdge)) {
          throw Errors.duplicateConditionalEdgeError.exception(sourceId);
       }
       else {
          edges.elements.add(newEdge);
       }
       return this;
    }

    /**
     * Adds conditional edges to the graph based on the provided edge action condition and
     * mappings.
     * @param sourceId the identifier of the source node
     * @param condition the edge action used to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @return this state graph instance
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
          throws GraphStateException {
       return addConditionalEdges(sourceId, AsyncCommandAction.of(condition), mappings);
    }

    /**
     * Validates the structure of the graph ensuring all connections are valid.
     * @throws GraphStateException if there are errors related to the graph state
     */
    void validateGraph() throws GraphStateException {
       var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);

       edgeStart.validate(nodes);

       validateNode(nodes);

       for (Edge edge : edges.elements) {
          edge.validate(nodes);
       }
    }

    private void validateNode(Nodes nodes) throws GraphStateException {
       List<CommandNode> commandNodeList = nodes.elements.stream().filter(node -> {
          return node instanceof CommandNode commandNode;
       }).map(node -> (CommandNode) node).toList();
       for (CommandNode commandNode : commandNodeList) {
          for (String key : commandNode.getMappings().keySet()) {
             if (!nodes.anyMatchById(key)) {
                throw Errors.missingNodeInEdgeMapping.exception(commandNode.id(), key);
             }
          }
       }
    }

    /**
     * Compiles the state graph into a compiled graph using the provided configuration.
     * @param config the compile configuration
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile(CompileConfig config) throws GraphStateException {
       Objects.requireNonNull(config, "config cannot be null");

       validateGraph();

       return new CompiledGraph(this, config);
    }

    /**
     * Compiles the state graph into a compiled graph using a default configuration with
     * memory saver.
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile() throws GraphStateException {
       SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
       return compile(CompileConfig.builder().saverConfig(saverConfig).build());
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @param printConditionalEdges whether to include conditional edges in the output
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
       String content = type.generator.generate(nodes, edges, title, printConditionalEdges);

       return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph with conditional edges
     * included.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
       String content = type.generator.generate(nodes, edges, title, true);

       return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph using the graph's name
     * as title.
     * @param type the type of graph representation to generate
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
       String content = type.generator.generate(nodes, edges, name, true);

       return new GraphRepresentation(type, content);
    }

    /**
     * Container for nodes in the graph.
     */
    public static class Nodes {

       /**
        * The collection of nodes.
        */
       public final Set<Node> elements;

       /**
        * Instantiates a new Nodes container with the provided elements.
        * @param elements the elements to initialize
        */
       public Nodes(Collection<Node> elements) {
          this.elements = new LinkedHashSet<>(elements);
       }

       /**
        * Instantiates a new empty Nodes container.
        */
       public Nodes() {
          this.elements = new LinkedHashSet<>();
       }

       /**
        * Checks if any node matches the given identifier.
        * @param id the identifier to match
        * @return true if a matching node is found, false otherwise
        */
       public boolean anyMatchById(String id) {
          return elements.stream().anyMatch(n -> Objects.equals(n.id(), id));
       }

       /**
        * Returns a list of sub-state graph nodes.
        * @return a list of sub-state graph nodes
        */
       public List<SubStateGraphNode> onlySubStateGraphNodes() {
          return elements.stream()
             .filter(n -> n instanceof SubStateGraphNode)
             .map(n -> (SubStateGraphNode) n)
             .toList();
       }

       /**
        * Returns a list of nodes excluding sub-state graph nodes.
        * @return a list of nodes excluding sub-state graph nodes
        */
       public List<Node> exceptSubStateGraphNodes() {
          return elements.stream().filter(n -> !(n instanceof SubStateGraphNode)).toList();
       }

    }

    /**
     * Container for edges in the graph.
     */
    public static class Edges {

       /**
        * The collection of edges.
        */
       public final List<Edge> elements;

       /**
        * Instantiates a new Edges container with the provided elements.
        * @param elements the elements to initialize
        */
       public Edges(Collection<Edge> elements) {
          this.elements = new LinkedList<>(elements);
       }

       /**
        * Instantiates a new empty Edges container.
        */
       public Edges() {
          this.elements = new LinkedList<>();
       }

       /**
        * Retrieves the first edge matching the specified source identifier.
        * @param sourceId the source identifier to match
        * @return an optional containing the matched edge, or empty if none found
        */
       public Optional<Edge> edgeBySourceId(String sourceId) {
          return elements.stream().filter(e -> Objects.equals(e.sourceId(), sourceId)).findFirst();
       }

       /**
        * Retrieves a list of edges targeting the specified node identifier.
        * @param targetId the target identifier to match
        * @return a list of edges targeting the specified identifier
        */
       public List<Edge> edgesByTargetId(String targetId) {
          return elements.stream().filter(e -> e.anyMatchByTargetId(targetId)).toList();
       }

    }

}
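Two behaviors of this class are easy to miss.

- **Edge merging.** addEdge looks up an existing edge with indexOf, which, judging by the merge that follows, relies on Edge equality comparing only the source ID. A second plain edge from the same source is therefore folded into one Edge with several targets instead of being stored separately.
- **Entry point.** validateGraph requires an edge leaving START; otherwise it throws the missing-entry-point error.

The sketch below models both with hypothetical MiniStateGraph and MiniEdge classes. Two simplifications apply:

- The node checks in validate() are only a guess at what Edge#validate does.
- MiniGraphException is unchecked for brevity, whereas the real GraphStateException is declared in throws clauses.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;

// Unchecked for brevity; the real GraphStateException is a checked exception.
class MiniGraphException extends RuntimeException {
    MiniGraphException(String message) {
        super(message);
    }
}

// Hypothetical edge model whose equality depends only on sourceId (assumed to mirror Edge).
final class MiniEdge {
    final String sourceId;
    final List<String> targets;

    MiniEdge(String sourceId, List<String> targets) {
        this.sourceId = sourceId;
        this.targets = List.copyOf(targets);
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof MiniEdge other && Objects.equals(other.sourceId, sourceId);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(sourceId);
    }
}

class MiniStateGraph {
    static final String START = "START";
    static final String END = "END";

    final Set<String> nodeIds = new LinkedHashSet<>();
    final List<MiniEdge> edges = new LinkedList<>();

    MiniStateGraph addNode(String id) {
        if (Objects.equals(id, END)) {
            throw new MiniGraphException("invalid node identifier: " + id);
        }
        if (!nodeIds.add(id)) {
            throw new MiniGraphException("duplicate node: " + id);
        }
        return this;
    }

    MiniStateGraph addEdge(String sourceId, String targetId) {
        if (Objects.equals(sourceId, END)) {
            throw new MiniGraphException("END cannot be an edge source");
        }
        MiniEdge newEdge = new MiniEdge(sourceId, List.of(targetId));
        int index = edges.indexOf(newEdge); // finds any edge with the same source
        if (index >= 0) {
            // Merge: one edge with several targets (fan-out), as in StateGraph#addEdge.
            List<String> merged = new ArrayList<>(edges.get(index).targets);
            merged.add(targetId);
            edges.set(index, new MiniEdge(sourceId, merged));
        } else {
            edges.add(newEdge);
        }
        return this;
    }

    void validate() {
        boolean hasEntryPoint = edges.stream().anyMatch(e -> e.sourceId.equals(START));
        if (!hasEntryPoint) {
            throw new MiniGraphException("missing entry point");
        }
        for (MiniEdge edge : edges) {
            if (!edge.sourceId.equals(START) && !nodeIds.contains(edge.sourceId)) {
                throw new MiniGraphException("unknown source node: " + edge.sourceId);
            }
            for (String target : edge.targets) {
                if (!target.equals(END) && !nodeIds.contains(target)) {
                    throw new MiniGraphException("unknown target node: " + target);
                }
            }
        }
    }
}
```

In the example exercised below, node a ends up with a single edge targeting both b and c. The sketch does not model how CompiledGraph then executes an edge with several targets.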

CompileConfig

Configures how the graph is compiled, including checkpoint savers, interrupt settings, lifecycle listeners and more.

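| Field | Type | Description |
|---|---|---|
| saverConfig | SaverConfig | Saver configuration that manages the checkpoint savers; replaces the old checkpointSaver field. Defaults to a MemorySaver registered under SaverConstant.MEMORY |
| lifecycleListeners | Deque | Queue of graph lifecycle listeners (GraphLifecycleListener) notified of node execution events; a LinkedBlockingDeque with a capacity of 25 |
| interruptsBefore | Set | IDs of the nodes before which execution is interrupted |
| interruptsAfter | Set | IDs of the nodes after which execution is interrupted |
| releaseThread | boolean | Thread-release flag: whether the thread is released during execution (see BaseCheckpointSaver#release(RunnableConfig)); defaults to false |
| observationRegistry | ObservationRegistry | Observation registry used for monitoring and tracing; defaults to ObservationRegistry.NOOP |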
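The listener deque is created as new LinkedBlockingDeque<>(25). Builder#withLifecycleListener adds to it with offer() and ignores the boolean result.

On a bounded LinkedBlockingDeque, offer() returns false instead of throwing when the deque is full. A 26th listener would therefore be dropped silently.

The sketch below demonstrates that JDK behavior on its own, using Runnable as a hypothetical stand-in for GraphLifecycleListener.

```java
import java.util.Deque;
import java.util.concurrent.LinkedBlockingDeque;

class ListenerCapacityDemo {

    // Mirrors CompileConfig: a deque bounded to 25 listeners.
    static final int CAPACITY = 25;

    // Tries to register `attempts` listeners and reports how many were actually stored.
    static int register(int attempts) {
        Deque<Runnable> listeners = new LinkedBlockingDeque<>(CAPACITY);
        int stored = 0;
        for (int i = 0; i < attempts; i++) {
            // offer() returns false when the deque is full; it does not throw.
            if (listeners.offer(() -> { })) {
                stored++;
            }
        }
        return stored;
    }
}
```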

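Public methods

| Method | Description |
|---|---|
| releaseThread | Returns the current value of the thread-release flag |
| lifecycleListeners | Returns the queue of lifecycle listeners. Despite its Javadoc ("unmodifiable list"), it returns the internal deque itself, so callers can modify it |
| observationRegistry | Returns the observation registry used for monitoring and tracing |
| interruptsBefore | Returns the set of nodes before which execution is interrupted |
| interruptsAfter | Returns the set of nodes after which execution is interrupted |
| checkpointSaver | Without arguments, retrieves the default checkpoint saver from the saver configuration; checkpointSaver(String type) retrieves the saver registered under the given type. Both return an Optional |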
package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import io.micrometer.observation.ObservationRegistry;

import java.util.Collection;
import java.util.Deque;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;

import static com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant.MEMORY;
import static java.util.Optional.ofNullable;

/**
 * class is a configuration container for defining compile settings and behaviors. It
 * includes various fields and methods to manage checkpoint savers and interrupts,
 * providing both deprecated and current accessors.
 */
public class CompileConfig {

    private SaverConfig saverConfig = new SaverConfig().register(MEMORY, new MemorySaver());

    private Deque<GraphLifecycleListener> lifecycleListeners = new LinkedBlockingDeque<>(25);

    // private BaseCheckpointSaver checkpointSaver; // replaced with SaverConfig
    private Set<String> interruptsBefore = Set.of();

    private Set<String> interruptsAfter = Set.of();

    private boolean releaseThread = false;

    private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

    /**
     * Returns the current state of the thread release flag.
     *
     * @see BaseCheckpointSaver#release(RunnableConfig)
     * @return true if the thread has been released, false otherwise
     */
    public boolean releaseThread() {
       return releaseThread;
    }

    /**
     * Gets an unmodifiable list of node lifecycle listeners.
     * @return The list of lifecycle listeners.
     */
    public Queue<GraphLifecycleListener> lifecycleListeners() {
       return lifecycleListeners;
    }

    /**
     * Gets observation registry for monitoring and tracing.
     * @return The observation registry instance.
     */
    public ObservationRegistry observationRegistry() {
       return observationRegistry;
    }

    /**
     * Returns the array of interrupts that will occur before the specified node
     * (deprecated).
     * @return An array of interruptible nodes.
     * @deprecated Use {@link #interruptsBefore()} instead for better immutability and
     * type safety.
     */
    @Deprecated
    public String[] getInterruptBefore() {
       return interruptsBefore.toArray(new String[0]);
    }

    /**
     * Returns the array of interrupts that will occur after the specified node
     * (deprecated).
     * @return An array of interruptible nodes.
     * @deprecated Use {@link #interruptsAfter()} instead for better immutability and type
     * safety.
     */
    @Deprecated
    public String[] getInterruptAfter() {
       return interruptsAfter.toArray(new String[0]);
    }

    /**
     * Returns the set of interrupts that will occur before the specified node.
     * @return An unmodifiable set of interruptible nodes.
     */
    public Set<String> interruptsBefore() {
       return interruptsBefore;
    }

    /**
     * Returns the set of interrupts that will occur after the specified node.
     * @return An unmodifiable set of interruptible nodes.
     */
    public Set<String> interruptsAfter() {
       return interruptsAfter;
    }

    /**
     * Retrieves a checkpoint saver based on the specified type from the saver
     * configuration.
     * @param type The type of the checkpoint saver to retrieve.
     * @return An Optional containing the checkpoint saver if available; otherwise, empty.
     */
    public Optional<BaseCheckpointSaver> checkpointSaver(String type) {
       return ofNullable(saverConfig.get(type));
    }

    /**
     * Retrieves the default checkpoint saver from the saver configuration.
     * @return An Optional containing the default checkpoint saver if available;
     * otherwise, empty.
     */
    public Optional<BaseCheckpointSaver> checkpointSaver() {
       return ofNullable(saverConfig.get());
    }

    /**
     * Returns a new instance of the builder with default configuration settings.
     * @return A new Builder instance.
     */
    public static Builder builder() {
       return new Builder(new CompileConfig());
    }

    /**
     * Returns a new instance of the builder initialized with the provided configuration.
     * @param config The compile configuration to use as a base.
     * @return A new Builder instance initialized with the given configuration.
     */
    public static Builder builder(CompileConfig config) {
       return new Builder(config);
    }

    /**
     * Builder class for creating instances of CompileConfig. It allows setting various
     * options such as savers, interrupts, and lifecycle listeners in a fluent manner.
     */
    public static class Builder {

       private final CompileConfig config;

       /**
        * Initializes the builder with the provided compile configuration.
        * @param config The base configuration to start from.
        */
       protected Builder(CompileConfig config) {
          this.config = new CompileConfig(config);
       }

       /**
        * Sets whether the thread should be released during execution.
        * @param releaseThread Flag indicating whether to release the thread.
        * @see BaseCheckpointSaver#release(RunnableConfig)
        * @return This builder instance for method chaining.
        */
       public Builder releaseThread(boolean releaseThread) {
          this.config.releaseThread = releaseThread;
          return this;
       }

       /**
        * Sets the observation registry for monitoring and tracing.
        * @param observationRegistry The ObservationRegistry to use.
        * @return This builder instance for method chaining.
        */
       public Builder observationRegistry(ObservationRegistry observationRegistry) {
          this.config.observationRegistry = observationRegistry;
          return this;
       }

       /**
        * Sets the saver configuration for checkpoints.
        * @param saverConfig The SaverConfig to use.
        * @return This builder instance for method chaining.
        */
       public Builder saverConfig(SaverConfig saverConfig) {
          this.config.saverConfig = saverConfig;
          return this;
       }

       /**
        * Sets individual interrupt points that trigger before node execution using
        * varargs.
        * @param interruptBefore One or more strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptBefore(String... interruptBefore) {
          this.config.interruptsBefore = Set.of(interruptBefore);
          return this;
       }

       /**
        * Sets individual interrupt points that trigger after node execution using
        * varargs.
        * @param interruptAfter One or more strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptAfter(String... interruptAfter) {
          this.config.interruptsAfter = Set.of(interruptAfter);
          return this;
       }

       /**
        * Sets multiple interrupt points that trigger before node execution from a
        * collection.
        * @param interruptsBefore Collection of strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptsBefore(Collection<String> interruptsBefore) {
          this.config.interruptsBefore = interruptsBefore.stream().collect(Collectors.toUnmodifiableSet());
          return this;
       }

       /**
        * Sets multiple interrupt points that trigger after node execution from a
        * collection.
        * @param interruptsAfter Collection of strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptsAfter(Collection<String> interruptsAfter) {
          this.config.interruptsAfter = interruptsAfter.stream().collect(Collectors.toUnmodifiableSet());
          return this;
       }

       /**
        * Adds a lifecycle listener to monitor node execution events.
        * @param listener The GraphLifecycleListener to add.
        * @return This builder instance for method chaining.
        */
       public Builder withLifecycleListener(GraphLifecycleListener listener) {
          this.config.lifecycleListeners.offer(listener);
          return this;
       }

       /**
        * Finalizes the configuration and returns the compiled instance.
        * @return The configured CompileConfig object.
        */
       public CompileConfig build() {
          return config;
       }

    }

    /**
     * Default constructor used internally to create a new configuration with default
     * settings. Made private to ensure all instances are created through the builder
     * pattern.
     */
    private CompileConfig() {
    }

    /**
     * Copy constructor to create a new instance based on an existing configuration.
     * @param config The configuration to copy.
     */
    private CompileConfig(CompileConfig config) {
       this.saverConfig = config.saverConfig;
       this.interruptsBefore = config.interruptsBefore;
       this.interruptsAfter = config.interruptsAfter;
       this.releaseThread = config.releaseThread;
       this.lifecycleListeners = config.lifecycleListeners;
       this.observationRegistry = config.observationRegistry;
    }

}
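Two details of the builder are worth knowing.

- **Shared listener deque.** The copy constructor copies field references, so builder(existingConfig) shares the existing config's listener deque. Adding a listener through the derived builder therefore also changes the original config. builder() is unaffected, because it copies a throwaway default instance.
- **Duplicate interrupt IDs.** interruptBefore(String...) uses Set.of, which throws IllegalArgumentException on duplicate IDs. interruptsBefore(Collection) collects with Collectors.toUnmodifiableSet() and keeps a single copy instead.

The reduced MiniCompileConfig below is a hypothetical model that keeps only those two fields, with String in place of GraphLifecycleListener.

```java
import java.util.Collection;
import java.util.Deque;
import java.util.Set;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;

// Hypothetical reduced model of CompileConfig: only listeners and interruptsBefore.
class MiniCompileConfig {

    Deque<String> lifecycleListeners = new LinkedBlockingDeque<>(25);

    Set<String> interruptsBefore = Set.of();

    private MiniCompileConfig() {
    }

    // Shallow copy, like the private CompileConfig(CompileConfig) constructor.
    private MiniCompileConfig(MiniCompileConfig other) {
        this.lifecycleListeners = other.lifecycleListeners; // the same Deque instance is shared
        this.interruptsBefore = other.interruptsBefore; // immutable, so sharing is harmless
    }

    static Builder builder() {
        return new Builder(new MiniCompileConfig());
    }

    static Builder builder(MiniCompileConfig base) {
        return new Builder(base);
    }

    static class Builder {

        private final MiniCompileConfig config;

        Builder(MiniCompileConfig base) {
            this.config = new MiniCompileConfig(base);
        }

        Builder withLifecycleListener(String listener) {
            config.lifecycleListeners.offer(listener);
            return this;
        }

        // Set.of rejects duplicates (IllegalArgumentException) and nulls (NullPointerException).
        Builder interruptBefore(String... nodeIds) {
            config.interruptsBefore = Set.of(nodeIds);
            return this;
        }

        // Duplicates are collapsed; nulls are still rejected.
        Builder interruptsBefore(Collection<String> nodeIds) {
            config.interruptsBefore = nodeIds.stream().collect(Collectors.toUnmodifiableSet());
            return this;
        }

        MiniCompileConfig build() {
            return config;
        }
    }
}
```

If the original config must stay untouched, register listeners only through a fresh builder() to avoid the sharing.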

CompiledGraph

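A core component of the graph framework. It represents a compiled graph structure: the executable form that StateGraph#compile produces from a StateGraph.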

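| Field | Type | Description |
|---|---|---|
| stateGraph | StateGraph | The original state graph this graph was compiled from |
| keyStrategyMap | Map | Mapping from state keys to their update strategies |
| nodes | Map | Compiled node map |
| edges | Map | Compiled edge map |
| processedData | ProcessedNodesEdgesAndConfig | Holds the processed nodes, edges and configuration |
| maxIterations | int | Maximum number of iterations; defaults to 25 |
| compileConfig | CompileConfig | Compile configuration |
| INTERRUPTAFTER | String | Static string constant whose value is "INTERRUPTED" |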

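Public methods

| Method | Description |
|---|---|
| **Graph execution** | |
| stream() | Creates a streaming output with default settings |
| stream(Map inputs) | Creates a streaming output for the given inputs |
| stream(Map inputs, RunnableConfig config) | Creates a streaming output for the given inputs and runtime configuration |
| streamFromInitialNode(OverAllState overAllState, RunnableConfig config) | Streams execution starting from the initial node |
| invoke(Map inputs) | Executes the graph with the given inputs |
| invoke(OverAllState overAllState, RunnableConfig config) | Executes the graph starting from the given initial state |
| invoke(Map inputs, RunnableConfig config) | Executes the graph and returns the final state |
| updateState(RunnableConfig config, Map values, String asNode) | Updates the graph state |
| getStateHistory(RunnableConfig config) | Returns the state history |
| getState(RunnableConfig config) | Returns the state snapshot for the given configuration |
| stateOf(RunnableConfig config) | Returns a state snapshot |
| setMaxIterations(int maxIterations) | Sets the maximum number of iterations |
| **Graph representation** | |
| getGraph(GraphRepresentation.Type type) | Generates a graph representation with default settings |
| getGraph(GraphRepresentation.Type type, String title) | Generates a graph representation with the given title |
| getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) | Generates a graph representation of the given type, optionally including conditional edges |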
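The stream and invoke families differ in what they hand back. stream exposes the output of every executed node, while invoke is described as returning the final state.

Conceptually, invoke can be thought of as running the stream and keeping only its last element. The hypothetical LinearRunner below illustrates that relationship on a fixed node order. It ignores edges, checkpoints and key strategies (every update simply replaces the old value), so it models the idea rather than CompiledGraph itself.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

// Hypothetical stand-in for a node output: the node that ran and the state after it.
record StepOutput(String node, Map<String, Object> state) {}

class LinearRunner {

    private final List<String> order;
    private final Map<String, Function<Map<String, Object>, Map<String, Object>>> actions;

    LinearRunner(List<String> order, Map<String, Function<Map<String, Object>, Map<String, Object>>> actions) {
        this.order = List.copyOf(order);
        this.actions = Map.copyOf(actions);
    }

    // Like stream(inputs): one output per executed node.
    List<StepOutput> stream(Map<String, Object> inputs) {
        Map<String, Object> state = new HashMap<>(inputs);
        List<StepOutput> outputs = new ArrayList<>();
        for (String node : order) {
            Map<String, Object> update = actions.get(node).apply(state); // partial update
            state.putAll(update); // "replace" strategy for every key
            outputs.add(new StepOutput(node, Map.copyOf(state)));
        }
        return outputs;
    }

    // Like invoke(inputs): the same run, but only the final state is kept.
    Optional<Map<String, Object>> invoke(Map<String, Object> inputs) {
        return stream(inputs).stream().reduce((first, second) -> second).map(StepOutput::state);
    }

    // Demo graph: "double" doubles n, then "describe" writes a text summary.
    static LinearRunner demo() {
        return new LinearRunner(List.of("double", "describe"), Map.of(
                "double", s -> Map.of("n", (Integer) s.get("n") * 2),
                "describe", s -> Map.of("text", "n=" + s.get("n"))));
    }
}
```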

StreamMode enum

  • VALUES: value streaming mode
  • SNAPSHOTS: snapshot streaming mode

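AsyncNodeGenerator inner class: drives the asynchronous execution of the graph and handles its streaming output

Core fields:

  • Cursor cursor: cursor object that tracks the current and next node IDs

    • String currentNodeId: ID of the current node
    • String nextNodeId: ID of the next node
    • String resumeFrom: ID of the node from which execution resumes
  • int iteration: current iteration count

  • RunnableConfig config: runtime configuration

  • boolean returnFromEmbed: marks whether execution is returning from an embedded generator

  • Map<String, Object> currentState: the current state as a Map

  • OverAllState overAllState: the wrapped OverAllState object, which offers richer state operations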

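Core methods:

  • next: executes the next step of the graph
  • evaluateAction: evaluates and executes a node's action
  • nextNodeId: determines the next node from the current node and the state
  • getEmbedGenerator: extracts an embedded generator from a partial state
  • processGeneratorOutput: processes the data output by a generator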
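
Taken together, `next`, `nextNodeId` and the `cursor`/`iteration` fields amount to a loop: run the node under the cursor, resolve the next node from its outgoing edge (fixed or conditional), and stop at the end node or once `maxIterations` is exceeded. The synchronous sketch below shows only that control flow. `FixedEdge`, `ConditionalEdge`, `run`, `counterDemo` and the `"START"`/`"END"` markers are hypothetical, and the real `AsyncNodeGenerator` additionally runs asynchronously and handles embedded generators, interrupts and resumption.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.function.UnaryOperator;

class ExecutionLoopSketch {

    // Hypothetical start/end markers, not the framework's own constants.
    static final String START = "START";
    static final String END = "END";

    interface EdgeTarget {}

    // Fixed edge: always continues with the same node.
    record FixedEdge(String target) implements EdgeTarget {}

    // Conditional edge: a routing function yields a key, mappings turn it into a node ID.
    record ConditionalEdge(Function<Map<String, Object>, String> route,
            Map<String, String> mappings) implements EdgeTarget {}

    // Plays the role of nextNodeId: resolve the node that follows `current`.
    static String nextNodeId(String current, Map<String, Object> state, Map<String, EdgeTarget> edges) {
        EdgeTarget edge = edges.get(current);
        if (edge == null) {
            throw new IllegalStateException("no edge from " + current);
        }
        if (edge instanceof FixedEdge fixed) {
            return fixed.target();
        }
        ConditionalEdge conditional = (ConditionalEdge) edge;
        String key = conditional.route().apply(state);
        String target = conditional.mappings().get(key);
        if (target == null) {
            throw new IllegalStateException("no mapping for key '" + key + "'");
        }
        return target;
    }

    // Plays the role of the next() loop: a cursor, an iteration counter and a hard limit.
    static Map<String, Object> run(Map<String, UnaryOperator<Map<String, Object>>> nodes,
            Map<String, EdgeTarget> edges, Map<String, Object> input, int maxIterations) {
        Map<String, Object> state = new HashMap<>(input);
        String current = nextNodeId(START, state, edges);
        int iteration = 0;
        while (!END.equals(current)) {
            if (++iteration > maxIterations) {
                throw new IllegalStateException("maxIterations (" + maxIterations + ") exceeded");
            }
            // Each node returns a partial update; putAll models a replace-style strategy only.
            state.putAll(nodes.get(current).apply(state));
            current = nextNodeId(current, state, edges);
        }
        return state;
    }

    // Demo graph: START -> inc, then inc loops on itself until count reaches 3.
    static Map<String, Object> counterDemo(int maxIterations) {
        Map<String, UnaryOperator<Map<String, Object>>> nodes = Map.of(
                "inc", s -> Map.of("count", (Integer) s.get("count") + 1));
        Map<String, EdgeTarget> edges = Map.of(
                START, new FixedEdge("inc"),
                "inc", new ConditionalEdge(s -> (Integer) s.get("count") < 3 ? "again" : "done",
                        Map.of("again", "inc", "done", END)));
        return run(nodes, edges, Map.of("count", 0), maxIterations);
    }

    public static void main(String[] args) {
        System.out.println(counterDemo(25)); // {count=3}
    }
}
```

The iteration limit is what keeps a conditional loop like `inc -> inc` from running forever: with `counterDemo(2)` the third visit to `inc` trips the guard and throws.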

Node

Core class for graph nodes. It defines a single node in the graph, consisting of the node's identifier and an optional action factory.

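Fields:

  • String id: unique identifier of the node
  • ActionFactory actionFactory: factory that creates the action the node executes

Public methods:

  • isParallel: checks whether the node is a parallel node; the current implementation always returns false
  • withIdUpdated: returns a new Node instance whose ID has been transformed by the given function
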
package com.alibaba.cloud.ai.graph.internal.node;

import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;

import java.util.Objects;
import java.util.function.Function;

import static java.lang.String.format;

/**
 * Represents a node in a graph, characterized by a unique identifier and a factory for
 * creating actions to be executed by the node. This is a generic record where the state
 * type is specified by the type parameter {@code State}.
 *
 * {@link OverAllState}.
 *
 */
public class Node {

    public interface ActionFactory {

       AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;

    }

    private final String id;

    private final ActionFactory actionFactory;

    public Node(String id, ActionFactory actionFactory) {
       this.id = id;
       this.actionFactory = actionFactory;
    }

    /**
     * Constructor that accepts only the `id` and sets `actionFactory` to null.
     * @param id the unique identifier for the node
     */
    public Node(String id) {
       this(id, null);
    }

    /**
     * id
     * @return the unique identifier for the node.
     */
    public String id() {
       return id;
    }

    /**
     * actionFactory
     * @return a factory function that takes a {@link CompileConfig} and returns an
     * {@link AsyncNodeActionWithConfig} instance for the specified {@code State}.
     */
    public ActionFactory actionFactory() {
       return actionFactory;
    }

    public boolean isParallel() {
       // return id.startsWith(PARALLELPREFIX);
       return false;
    }

    public Node withIdUpdated(Function<String, String> newId) {
       return new Node(newId.apply(id), actionFactory);
    }

    /**
     * Checks if this node is equal to another object.
     * @param o the object to compare with
     * @return true if this node is equal to the specified object, false otherwise
     */
    @Override
    public boolean equals(Object o) {
       if (this == o)
          return true;
       if (o == null)
          return false;
       if (o instanceof Node node) {
          return Objects.equals(id, node.id);
       }
       return false;

    }

    /**
     * Returns the hash code value for this node.
     * @return the hash code value for this node
     */
    @Override
    public int hashCode() {
       return Objects.hash(id);
    }

    @Override
    public String toString() {
       return format("Node(%s,%s)", id, actionFactory != null ? "action" : "null");
    }

}
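
Two details of `Node` are easy to miss: `equals`/`hashCode` look only at `id`, so two nodes with the same ID but different action factories count as the same node (for example in a `Set`), and `withIdUpdated` produces a renamed copy that reuses the same factory. The sketch below reproduces just those two behaviors in a hypothetical `NodeSketch` class, with the factory reduced to a `Supplier<String>`. Prefixing IDs this way is the kind of renaming needed when, say, inlining a subgraph, but the `"sub-"` prefix is purely illustrative.

```java
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical reduction of Node: an id plus an opaque factory.
final class NodeSketch {

    private final String id;
    private final Supplier<String> actionFactory; // stand-in for Node.ActionFactory

    NodeSketch(String id, Supplier<String> actionFactory) {
        this.id = id;
        this.actionFactory = actionFactory;
    }

    String id() {
        return id;
    }

    Supplier<String> actionFactory() {
        return actionFactory;
    }

    // Same idea as Node.withIdUpdated: new instance, transformed id, same factory.
    NodeSketch withIdUpdated(Function<String, String> newId) {
        return new NodeSketch(newId.apply(id), actionFactory);
    }

    // Identity is the id alone; the factory is ignored, as in Node.equals.
    @Override
    public boolean equals(Object o) {
        return this == o || (o instanceof NodeSketch other && Objects.equals(id, other.id));
    }

    @Override
    public int hashCode() {
        return Objects.hash(id);
    }

    public static void main(String[] args) {
        Set<NodeSketch> nodes = new HashSet<>();
        nodes.add(new NodeSketch("agent", () -> "v1"));
        nodes.add(new NodeSketch("agent", () -> "v2")); // not added: same id
        System.out.println(nodes.size()); // 1

        NodeSketch renamed = new NodeSketch("agent", () -> "v1").withIdUpdated(id -> "sub-" + id);
        System.out.println(renamed.id() + " " + renamed.actionFactory().get()); // sub-agent v1
    }
}
```

Comparing by ID alone means a graph cannot hold two nodes with the same name, whatever their actions are.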

Edge

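Defines the connection between nodes. It is a Java record.

Fields:

  • String sourceId: ID of the source node, i.e. the node where the edge starts
  • List<EdgeValue> targets: the edge's target nodes or conditions

Methods:

Constructors
  • Edge
    • (String id): creates an edge that has only a source node and no targets
    • (String sourceId, EdgeValue target): creates an edge from the source node to a single target
    • (String sourceId, List targets): creates an edge from the source node to multiple targets

Edge checks and updates
  • isParallel: whether the edge is parallel (more than one target)
  • target: returns the single target; throws an IllegalStateException if the edge is parallel
  • anyMatchByTargetId: checks whether any target points to the given node ID (for a conditional target, by looking at its mapping values)
  • withSourceAndTargetIdsUpdated: updates the source node ID and the target values, returning a new Edge instance
  • validate: validates the edge; the referenced nodes must exist, and a parallel edge must not contain duplicate targets
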
package com.alibaba.cloud.ai.graph.internal.edge;

import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.internal.node.Node;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static java.lang.String.format;

/**
 * Represents an edge in a graph with a source ID and a target value.
 *
 * @param sourceId The ID of the source node.
 * @param targets The targets value associated with the edge.
 */
public record Edge(String sourceId, List<EdgeValue> targets) {

    public Edge(String sourceId, EdgeValue target) {
       this(sourceId, List.of(target));
    }

    public Edge(String id) {
       this(id, List.of());
    }

    public boolean isParallel() {
       return targets.size() > 1;
    }

    public EdgeValue target() {
       if (isParallel()) {
          throw new IllegalStateException(format("Edge '%s' is parallel", sourceId));
       }
       return targets.get(0);
    }

    public boolean anyMatchByTargetId(String targetId) {
       return targets().stream()
          .anyMatch(v -> (v.id() != null) ? Objects.equals(v.id(), targetId)
                : v.value().mappings().containsValue(targetId)

          );
    }

    public Edge withSourceAndTargetIdsUpdated(Node node, Function<String, String> newSourceId,
          Function<String, EdgeValue> newTarget) {

       var newTargets = targets().stream().map(t -> t.withTargetIdsUpdated(newTarget)).toList();
       return new Edge(newSourceId.apply(sourceId), newTargets);

    }

    public void validate(StateGraph.Nodes nodes) throws GraphStateException {
       if (!Objects.equals(sourceId(), START) && !nodes.anyMatchById(sourceId())) {
          throw Errors.missingNodeReferencedByEdge.exception(sourceId());
       }

       if (isParallel()) { // check for duplicates targets
          Set<String> duplicates = targets.stream()
             .collect(Collectors.groupingBy(EdgeValue::id, Collectors.counting())) // Group by element and count occurrences
             .entrySet()
             .stream()
             .filter(entry -> entry.getValue() > 1) // Filter elements with more than one occurrence
             .map(Map.Entry::getKey)
             .collect(Collectors.toSet());
          if (!duplicates.isEmpty()) {
             throw Errors.duplicateEdgeTargetError.exception(sourceId(), duplicates);
          }
       }

       for (EdgeValue target : targets) {
          validate(target, nodes);
       }

    }

    private void validate(EdgeValue target, StateGraph.Nodes nodes) throws GraphStateException {
       if (target.id() != null) {
          if (!Objects.equals(target.id(), StateGraph.END) && !nodes.anyMatchById(target.id())) {
             throw Errors.missingNodeReferencedByEdge.exception(target.id());
          }
       }
       else if (target.value() != null) {
          for (String nodeId : target.value().mappings().values()) {
             if (!Objects.equals(nodeId, StateGraph.END) && !nodes.anyMatchById(nodeId)) {
                throw Errors.missingNodeInEdgeMapping.exception(sourceId(), nodeId);
             }
          }
       }
       else {
          throw Errors.invalidEdgeTarget.exception(sourceId());
       }

    }

    /**
     * Checks if this edge is equal to another object.
     * @param o the object to compare with
     * @return true if this edge is equal to the specified object, false otherwise
     */
    @Override
    public boolean equals(Object o) {
       if (this == o)
          return true;
       if (o == null || getClass() != o.getClass())
          return false;
       Edge node = (Edge) o;
       return Objects.equals(sourceId, node.sourceId);
    }

    /**
     * Returns the hash code value for this edge.
     * @return the hash code value for this edge
     */
    @Override
    public int hashCode() {
       return Objects.hash(sourceId);
    }

}
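
The interplay between `isParallel`, `target()` and the duplicate check in `validate` can be tried in isolation. The sketch below re-creates just that logic with hypothetical `EdgeModel`/`Target` records that only support ID-based targets. Staying with ID-based targets matters for the duplicate check: the JDK's `groupingBy` collector rejects `null` keys with a `NullPointerException`, and `EdgeValue::id` returns `null` for a conditional target.

```java
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

class EdgeSketch {

    // Hypothetical simplification of EdgeValue: fixed (id-based) targets only.
    record Target(String id) {}

    // Hypothetical simplification of Edge: the same isParallel/target/duplicate logic.
    record EdgeModel(String sourceId, List<Target> targets) {

        boolean isParallel() {
            return targets.size() > 1;
        }

        // A parallel edge has no single target, so asking for one is an error.
        Target target() {
            if (isParallel()) {
                throw new IllegalStateException("Edge '" + sourceId + "' is parallel");
            }
            return targets.get(0);
        }

        boolean anyMatchByTargetId(String targetId) {
            return targets.stream().anyMatch(t -> t.id().equals(targetId));
        }

        // Same grouping idea as validate(): target ids that occur more than once.
        Set<String> duplicateTargetIds() {
            return targets.stream()
                .collect(Collectors.groupingBy(Target::id, Collectors.counting()))
                .entrySet()
                .stream()
                .filter(entry -> entry.getValue() > 1)
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
        }
    }

    public static void main(String[] args) {
        EdgeModel fanOut = new EdgeModel("plan", List.of(new Target("search"), new Target("summarize")));
        System.out.println(fanOut.isParallel()); // true
        System.out.println(fanOut.anyMatchByTargetId("summarize")); // true

        EdgeModel duplicated = new EdgeModel("plan", List.of(new Target("search"), new Target("search")));
        System.out.println(duplicated.duplicateTargetIds()); // [search]

        EdgeModel single = new EdgeModel("search", List.of(new Target("END")));
        System.out.println(single.target().id()); // END
    }
}
```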

EdgeValue

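Core class for edge values. It defines an edge's target node or condition, which is how conditional edges are supported.

Fields:

  • String id: unique identifier of the target node; used by fixed edges
  • EdgeCondition value: the condition attached to this edge value; used by conditional edges

Methods:

Constructors
  • EdgeValue
    • (String id): creates an EdgeValue with only a target node ID; the condition is null
    • (EdgeCondition value): creates an EdgeValue with only a condition; the target node ID is null
    • (String id, EdgeCondition value): creates an EdgeValue with both a target node ID and a condition

Update methods
  • withTargetIdsUpdated(Function target): updates the target node IDs and returns a new EdgeValue instance
    • if this EdgeValue has a target node ID, the function is applied to it
    • if this EdgeValue has a condition, the node IDs in the condition's mappings are updated
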
package com.alibaba.cloud.ai.graph.internal.edge;

import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * @param id The unique identifier for the edge value.
 * @param value The condition associated with the edge value.
 */
public record EdgeValue(String id, EdgeCondition value) {

    public EdgeValue(String id) {
       this(id, null);
    }

    public EdgeValue(EdgeCondition value) {
       this(null, value);
    }

    EdgeValue withTargetIdsUpdated(Function<String, EdgeValue> target) {
       if (id != null) {
          return target.apply(id);
       }

       var newMappings = value.mappings().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> {
          var v = target.apply(e.getValue());
          return (v.id() != null) ? v.id() : e.getValue();
       }));

       return new EdgeValue(null, new EdgeCondition(value.action(), newMappings));

    }

}
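
`withTargetIdsUpdated` treats the two kinds of edge values differently: a fixed value is replaced by whatever the function returns, while a conditional value keeps its action and mapping keys and only rewrites the node IDs on the value side of `mappings`, falling back to the old ID when the function returns a condition instead of an ID. The self-contained re-creation below uses hypothetical `ValueModel`/`ConditionModel` records in which the condition's action is reduced to a label string; the `"sub-"` prefix is illustrative.

```java
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

class EdgeValueSketch {

    // Hypothetical stand-in for EdgeCondition; the action is reduced to a name.
    record ConditionModel(String action, Map<String, String> mappings) {}

    // Hypothetical stand-in for EdgeValue: either a fixed id or a condition.
    record ValueModel(String id, ConditionModel value) {

        ValueModel withTargetIdsUpdated(Function<String, ValueModel> target) {
            if (id != null) {
                return target.apply(id); // fixed edge: the function decides the new value
            }
            // Conditional edge: keep the keys and the action, rewrite only the node IDs.
            Map<String, String> newMappings = value.mappings().entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> {
                    ValueModel v = target.apply(e.getValue());
                    return v.id() != null ? v.id() : e.getValue();
                }));
            return new ValueModel(null, new ConditionModel(value.action(), newMappings));
        }
    }

    public static void main(String[] args) {
        Function<String, ValueModel> prefix = nodeId -> new ValueModel("sub-" + nodeId, null);

        ValueModel fixed = new ValueModel("tool", null).withTargetIdsUpdated(prefix);
        System.out.println(fixed.id()); // sub-tool

        ValueModel conditional = new ValueModel(null,
                new ConditionModel("route", Map.of("yes", "tool", "no", "answer")))
            .withTargetIdsUpdated(prefix);
        System.out.println(conditional.value().mappings().get("yes")); // sub-tool
        System.out.println(conditional.value().mappings().get("no")); // sub-answer
    }
}
```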

EdgeCondition

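Defines the condition and the mapping relationships of a conditional edge.

Fields:

  • AsyncCommandAction action: asynchronous command action that runs the condition logic
  • Map mappings: mapping table from condition results to target nodes
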
package com.alibaba.cloud.ai.graph.internal.edge;

import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;

import java.util.Map;

import static java.lang.String.format;

/**
 * Represents a condition associated with an edge in a graph.
 *
 * @param action The action to be performed asynchronously when the edge condition is met.
 * @param mappings A map of string key-value pairs representing additional mappings for
 * the edge condition.
 */
public record EdgeCondition(AsyncCommandAction action, Map<String, String> mappings) {

    @Override
    public String toString() {
       return format("EdgeCondition[ %s, mapping=%s", action != null ? "action" : "null", mappings);
    }

}
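
At run time a conditional edge is resolved in two steps: the asynchronous action produces a routing key, and `mappings` translates that key into a node ID. The sketch below models the action as `Function<Map<String, Object>, CompletableFuture<String>>` purely for illustration. The framework's `AsyncCommandAction` has its own signature, which this sketch does not reproduce. The `Condition` record, the `resolve` helper and the node names are hypothetical.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

class EdgeConditionSketch {

    // Hypothetical model: action = async routing function, mappings = routing key -> node ID.
    record Condition(Function<Map<String, Object>, CompletableFuture<String>> action,
            Map<String, String> mappings) {}

    // Run the action, then translate its key through the mappings table.
    static CompletableFuture<String> resolve(Condition condition, Map<String, Object> state) {
        return condition.action().apply(state).thenApply(key -> {
            String nodeId = condition.mappings().get(key);
            if (nodeId == null) {
                throw new IllegalStateException("no mapping for key '" + key + "'");
            }
            return nodeId;
        });
    }

    public static void main(String[] args) {
        Condition needsTool = new Condition(
                state -> CompletableFuture.completedFuture(state.containsKey("tool_call") ? "tool" : "finish"),
                Map.of("tool", "call_tool", "finish", "END"));

        System.out.println(resolve(needsTool, Map.of("tool_call", "search")).join()); // call_tool
        System.out.println(resolve(needsTool, Map.of()).join()); // END
    }
}
```

Separating the routing key from the node ID is what lets `EdgeValue.withTargetIdsUpdated` rename target nodes without touching the condition logic itself.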

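Learning and discussion group

Hi, I'm 影子 (Shadow). I have worked at 🐻, a new-energy company and 老铁, and I am also a Committer in the Spring AI Alibaba open-source community. I've recently started a discussion group: one person walks fast, but a group walks far. Follow my WeChat official account to get my personal WeChat, then add me with the note "交流" to join the group. I also maintain a set of Feishu cloud-doc notes with systematic interview material on backend and big-data systems; message me privately for free access.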