Original article: "Spring AI Alibaba Graph Source Code Walkthrough Series: Core Startup Classes"
[!TIP] Based on Spring AI Alibaba version 1.0.0.3
- Project source code: github.com/alibaba/spr…
Core classes
OverAllState
The core state-management class used during graph execution. It stores and manages the data shared between the nodes of a graph.
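| Field | Type | Description |
|---|---|---|
| data | `Map<String, Object>` | Holds the actual state data |
| keyStrategies | `Map<String, KeyStrategy>` | The update strategy for each key |
| resume | Boolean | Whether the state is used to resume an interrupted execution |
| humanFeedback | HumanFeedback | Human feedback message |
| interruptMessage | String | Interrupt message |
| DEFAULTINPUTKEY | String | Static constant "input", the name of the default input key |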
Public methods
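| Category | Method | Description |
|---|---|---|
| Constructor | OverAllState | Four constructors: no-arg: creates empty data and strategy maps and registers a `ReplaceStrategy` for the default input key; `(Map<String, Object> data)`: initializes the state from a copy of existing data; `(boolean resume)`: sets whether the state is in resume mode; `(Map<String, Object> data, Map<String, KeyStrategy> keyStrategies, Boolean resume)`: full constructor (protected; uses the given maps as-is and also registers the default input key) |
| Data access | data | Returns an unmodifiable view of the data map |
| | value | `(String key)`: returns the value for the key wrapped in an `Optional`; `(String key, Class<T> type)`: returns the value cast to the given type, wrapped in an `Optional`; `(String key, T defaultValue)`: returns the value, or the default value if the key is absent |
| | keyStrategies | Returns the key strategy map |
| Strategy registration | registerKeyAndStrategy | Registers a strategy for one key, or a whole map of key strategies |
| | containStrategy | Checks whether a strategy is registered for a key |
| State update | input | Applies input data through the registered strategies (keys without a strategy are ignored); a `null` input switches the state to resume mode, and an empty input changes nothing |
| | updateState | Merges a partial state into the data through the registered strategies (keys without a strategy are ignored) and returns an unmodifiable view of the data |
| | updateStateBySchema | Merges `state` and `partialState` using the supplied key strategies (a `null` value drops the key), then applies the result through `updateState` |
| Snapshot and reset | snapShot | Returns a snapshot holding shallow copies of the data and strategy maps, with the same resume flag |
| | cover | Overwrites this state's data, key strategies, resume flag and human feedback with those of another state |
| | reset / clear | Clear the data only; strategies, resume flag and human feedback are kept |
| Resume mode | isResume | Checks whether the state is in resume mode |
| | withResume | Switches to resume mode |
| | withoutResume | Leaves resume mode |
| | copyWithResume | Returns a new instance with resume set to `true`; it shares (rather than copies) the data and strategy maps, and does not carry over the human feedback or interrupt message |
| Human feedback and interrupts | humanFeedback | Gets the human feedback |
| | withHumanFeedback | Sets the human feedback |
| | interruptMessage | Gets the interrupt message |
| | setInterruptMessage | Sets the interrupt message |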
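Both `input` and `updateState` follow the same rule: for every incoming key that has a registered strategy, the strategy combines the current value with the new one, and keys without a strategy are silently dropped. Below is a standalone sketch of that rule, not the library code. `SketchKeyStrategy`, `SketchStrategies` and `SketchState` are hypothetical names, the append behaviour is a simplified assumption rather than the library's own append strategy, and the sketch passes the raw current value to the strategy, which is a simplification of the real `value(key, null)` call.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

// Hypothetical stand-in for KeyStrategy: (currentValue, newValue) -> stored value.
interface SketchKeyStrategy extends BinaryOperator<Object> {
}

final class SketchStrategies {
    // Replace: the new value overwrites the current one.
    static final SketchKeyStrategy REPLACE = (currentValue, newValue) -> newValue;

    // Append (simplified assumption): gather old and new values into one list.
    static final SketchKeyStrategy APPEND = (currentValue, newValue) -> {
        List<Object> merged = new ArrayList<>();
        addFlattened(merged, currentValue);
        addFlattened(merged, newValue);
        return merged;
    };

    private static void addFlattened(List<Object> target, Object value) {
        if (value instanceof List<?>) {
            target.addAll((List<?>) value);
        } else if (value != null) {
            target.add(value);
        }
    }
}

final class SketchState {
    private final Map<String, Object> data = new HashMap<>();
    private final Map<String, SketchKeyStrategy> keyStrategies = new HashMap<>();

    SketchState registerKeyAndStrategy(String key, SketchKeyStrategy strategy) {
        keyStrategies.put(key, strategy);
        return this;
    }

    // Same rule as OverAllState.updateState(Map): keys without a registered
    // strategy are skipped; the others are merged with the current value.
    Map<String, Object> updateState(Map<String, Object> partialState) {
        partialState.forEach((key, newValue) -> {
            SketchKeyStrategy strategy = keyStrategies.get(key);
            if (strategy != null) {
                data.put(key, strategy.apply(data.get(key), newValue));
            }
        });
        return Collections.unmodifiableMap(data);
    }
}
```

Registering `input` with a replace strategy mirrors what the no-arg `OverAllState` constructor does for `DEFAULTINPUTKEY`.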
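The static nested class HumanFeedback stores the information a human provides while the workflow is running:
- Map<String, Object> data: the feedback payload, which can hold arbitrary key-value pairs
- String nextNodeId: the ID of the node to execute next
- String currentNodeId: declared, but neither assigned nor exposed by any method in this version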
package com.alibaba.cloud.ai.graph;
import org.springframework.util.CollectionUtils;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import static com.alibaba.cloud.ai.graph.utils.CollectionsUtils.entryOf;
import static java.util.Collections.unmodifiableMap;
import static java.util.Optional.ofNullable;
public final class OverAllState implements Serializable {
/**
* Internal map storing the actual state data. All get/set operations on state values
* go through this map.
*/
private final Map<String, Object> data;
/**
* Mapping of keys to their respective update strategies. Determines how values for
* each key should be merged or updated.
*/
private final Map<String, KeyStrategy> keyStrategies;
/**
* Indicates whether this state is being used to resume a previously interrupted
* execution. If true, certain initialization steps may be skipped.
*/
private Boolean resume;
/**
* Holds optional human feedback information provided during execution. May be null if
* no feedback was given.
*/
private HumanFeedback humanFeedback;
/**
* Optional message indicating that the execution was interrupted. If non-null,
* indicates that the graph should halt or handle the interruption.
*/
private String interruptMessage;
/**
* The default key used for standard input injection into the state. Typically used
* when initializing the state with user or external input.
*/
public static final String DEFAULTINPUTKEY = "input";
/**
* Reset.
*/
public void reset() {
this.data.clear();
}
/**
* Snap shot optional.
* @return the optional
*/
public Optional<OverAllState> snapShot() {
return Optional.of(new OverAllState(new HashMap<>(this.data), new HashMap<>(this.keyStrategies), this.resume));
}
/**
* Instantiates a new Over all state.
* @param resume the is resume
*/
public OverAllState(boolean resume) {
this.data = new HashMap<>();
this.keyStrategies = new HashMap<>();
this.resume = resume;
}
/**
* Instantiates a new Over all state.
* @param data the data
*/
public OverAllState(Map<String, Object> data) {
this.data = new HashMap<>(data);
this.keyStrategies = new HashMap<>();
this.resume = false;
}
/**
* Instantiates a new Over all state.
*/
public OverAllState() {
this.data = new HashMap<>();
this.keyStrategies = new HashMap<>();
this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
this.resume = false;
}
/**
* Instantiates a new Over all state.
* @param data the data
* @param keyStrategies the key strategies
* @param resume the resume
*/
protected OverAllState(Map<String, Object> data, Map<String, KeyStrategy> keyStrategies, Boolean resume) {
this.data = data;
this.keyStrategies = keyStrategies;
this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
this.resume = resume;
}
/**
* Interrupt message string.
* @return the string
*/
public String interruptMessage() {
return interruptMessage;
}
/**
* Sets interrupt message.
* @param interruptMessage the interrupt message
*/
public void setInterruptMessage(String interruptMessage) {
this.interruptMessage = interruptMessage;
}
/**
* With human feedback.
* @param humanFeedback the human feedback
*/
public void withHumanFeedback(HumanFeedback humanFeedback) {
this.humanFeedback = humanFeedback;
}
/**
* Human feedback human feedback.
* @return the human feedback
*/
public HumanFeedback humanFeedback() {
return this.humanFeedback;
}
/**
* Copy with resume over all state.
* @return the over all state
*/
public OverAllState copyWithResume() {
return new OverAllState(this.data, this.keyStrategies, true);
}
/**
* With resume.
*/
public void withResume() {
this.resume = true;
}
/**
* Without resume.
*/
public void withoutResume() {
this.resume = false;
}
/**
* Is resume boolean.
* @return the boolean
*/
public boolean isResume() {
return this.resume;
}
/**
* Clears all data in the current state, leaving key strategies, resume flag, and
* human feedback intact.
*/
public void clear() {
this.data.clear();
}
/**
* Replaces the current state's contents with the provided state.
* <p>
* This method effectively copies all data, key strategies, resume flag, and human
* feedback from the provided state to this state.
* @param overAllState the state to copy from
*/
public void cover(OverAllState overAllState) {
this.keyStrategies.clear();
this.keyStrategies.putAll(overAllState.keyStrategies());
this.data.clear();
this.data.putAll(overAllState.data());
this.resume = overAllState.resume;
this.humanFeedback = overAllState.humanFeedback;
}
/**
* Inputs over all state.
* @param input the input
* @return the over all state
*/
public OverAllState input(Map<String, Object> input) {
if (input == null) {
withResume();
return this;
}
if (CollectionUtils.isEmpty(input)) {
return this;
}
Map<String, KeyStrategy> keyStrategies = keyStrategies();
input.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
this.data.put(key, keyStrategies.get(key).apply(value(key, null), input.get(key)));
});
return this;
}
/**
* Add key and strategy over all state.
* @param key the key
* @param strategy the strategy
* @return the over all state
*/
public OverAllState registerKeyAndStrategy(String key, KeyStrategy strategy) {
this.keyStrategies.put(key, strategy);
return this;
}
/**
* Register key and strategy over all state.
* @param keyStrategies the key strategies
* @return the over all state
*/
public OverAllState registerKeyAndStrategy(Map<String, KeyStrategy> keyStrategies) {
this.keyStrategies.putAll(keyStrategies);
return this;
}
/**
* Is contain strategy boolean.
* @param key the key
* @return the boolean
*/
public boolean containStrategy(String key) {
return this.keyStrategies.containsKey(key);
}
/**
* Update state map.
* @param partialState the partial state
* @return the map
*/
public Map<String, Object> updateState(Map<String, Object> partialState) {
Map<String, KeyStrategy> keyStrategies = keyStrategies();
partialState.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
this.data.put(key, keyStrategies.get(key).apply(value(key, null), partialState.get(key)));
});
return data();
}
/**
* Updates the internal state based on a schema-defined strategy.
* <p>
* This method first validates the input state, then updates the partial state
* according to the provided key strategies. The updated state is formed by merging
* the original state and the modified partial state, removing any null values in the
* process. The resulting entries are then used to update the internal data map.
* @param state the base state to update; must not be null
* @param partialState the partial state containing updates; may be null or empty
* @param keyStrategies the mapping of keys to update strategies; used to transform
* values
*/
public void updateStateBySchema(Map<String, Object> state, Map<String, Object> partialState,
Map<String, KeyStrategy> keyStrategies) {
updateState(updateState(state, partialState, keyStrategies));
}
/**
* Key verify boolean.
* @return the boolean
*/
protected boolean keyVerify() {
return hasCommonKey(this.data, getKeyStrategies());
}
private Map<?, ?> getKeyStrategies() {
return this.keyStrategies;
}
private boolean hasCommonKey(Map<?, ?> map1, Map<?, ?> map2) {
Set<?> keys1 = map1.keySet();
for (Object key : map2.keySet()) {
if (keys1.contains(key)) {
return true;
}
}
return false;
}
/**
* Updates a state with the provided partial state. The merge function is used to
* merge the current state value with the new value.
* @param state the current state
* @param partialState the partial state to update from
* @return the updated state
* @throws NullPointerException if state is null
*/
public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState) {
Objects.requireNonNull(state, "state cannot be null");
if (partialState == null || partialState.isEmpty()) {
return state;
}
return Stream.concat(state.entrySet().stream(), partialState.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, OverAllState::mergeFunction));
}
/**
* Update state map.
* @param state the state
* @param partialState the partial state
* @param keyStrategies the key strategies
* @return the map
*/
public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState,
Map<String, KeyStrategy> keyStrategies) {
Objects.requireNonNull(state, "state cannot be null");
if (partialState == null || partialState.isEmpty()) {
return state;
}
Map<String, Object> updatedPartialState = updatePartialStateFromSchema(state, partialState, keyStrategies);
return Stream.concat(state.entrySet().stream(), updatedPartialState.entrySet().stream())
.collect(toMapRemovingNulls(Map.Entry::getKey, Map.Entry::getValue, (currentValue, newValue) -> newValue));
}
/**
* Updates the partial state from a schema using channels.
* @param state The current state as a map of key-value pairs.
* @param partialState The partial state to be updated.
* @param keyStrategies A map of channel names to their implementations.
* @return An updated version of the partial state after applying the schema and
* channels.
*/
private static Map<String, Object> updatePartialStateFromSchema(Map<String, Object> state,
Map<String, Object> partialState, Map<String, KeyStrategy> keyStrategies) {
if (keyStrategies == null || keyStrategies.isEmpty()) {
return partialState;
}
return partialState.entrySet().stream().map(entry -> {
KeyStrategy channel = keyStrategies.get(entry.getKey());
if (channel != null) {
Object newValue = channel.apply(state.get(entry.getKey()), entry.getValue());
return entryOf(entry.getKey(), newValue);
}
return entry;
}).collect(toMapAllowingNulls(Map.Entry::getKey, Map.Entry::getValue));
}
private static <T, K, U> Collector<T, ?, Map<K, U>> toMapRemovingNulls(Function<? super T, ? extends K> keyMapper,
Function<? super T, ? extends U> valueMapper, BinaryOperator<U> mergeFunction) {
return Collector.of(HashMap::new, (map, element) -> {
K key = keyMapper.apply(element);
U value = valueMapper.apply(element);
if (value == null) {
map.remove(key);
}
else {
map.merge(key, value, mergeFunction);
}
}, (map1, map2) -> {
map2.forEach((key, value) -> {
if (value != null) {
map1.merge(key, value, mergeFunction);
}
});
return map1;
}, Collector.Characteristics.UNORDERED);
}
private static <T, K, U> Collector<T, ?, Map<K, U>> toMapAllowingNulls(Function<? super T, ? extends K> keyMapper,
Function<? super T, ? extends U> valueMapper) {
return Collector.of(HashMap::new,
(map, element) -> map.put(keyMapper.apply(element), valueMapper.apply(element)), (map1, map2) -> {
map1.putAll(map2);
return map1;
}, Collector.Characteristics.UNORDERED);
}
/**
* Merges the current value with the new value using the appropriate merge function.
* @param currentValue the current value
* @param newValue the new value
* @return the merged value
*/
private static Object mergeFunction(Object currentValue, Object newValue) {
return newValue;
}
/**
* Key strategies map.
* @return the map
*/
public Map<String, KeyStrategy> keyStrategies() {
return keyStrategies;
}
/**
* Data map.
* @return the map
*/
public final Map<String, Object> data() {
return unmodifiableMap(data);
}
/**
* Value optional.
* @param <T> the type parameter
* @param key the key
* @return the optional
*/
public final <T> Optional<T> value(String key) {
return ofNullable((T) data().get(key));
}
/**
* Value optional.
* @param <T> the type parameter
* @param key the key
* @param type the type
* @return the optional
*/
public final <T> Optional<T> value(String key, Class<T> type) {
if (type != null) {
return ofNullable(type.cast(data().get(key)));
}
return value(key);
}
/**
* Value t.
* @param <T> the type parameter
* @param key the key
* @param defaultValue the default value
* @return the t
*/
public final <T> T value(String key, T defaultValue) {
return (T) value(key).orElse(defaultValue);
}
/**
* The type Human feedback.
*/
public static class HumanFeedback implements Serializable {
private Map<String, Object> data;
private String nextNodeId;
private String currentNodeId;
/**
* Instantiates a new Human feedback.
* @param data the data
* @param nextNodeId the next node id
*/
public HumanFeedback(Map<String, Object> data, String nextNodeId) {
this.data = data;
this.nextNodeId = nextNodeId;
}
/**
* Data map.
* @return the map
*/
public Map<String, Object> data() {
return data;
}
/**
* Next node id string.
* @return the string
*/
public String nextNodeId() {
return nextNodeId;
}
/**
* Sets data.
* @param data the data
*/
public void setData(Map<String, Object> data) {
this.data = data;
}
/**
* Sets next node id.
* @param nextNodeId the next node id
*/
public void setNextNodeId(String nextNodeId) {
this.nextNodeId = nextNodeId;
}
}
@Override
public String toString() {
return "OverAllState{" + "data=" + data + ", resume=" + resume + ", humanFeedback=" + humanFeedback
+ ", interruptMessage='" + interruptMessage + '\'' + '}';
}
}
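The static `updateState(state, partialState, keyStrategies)` works in two passes: `updatePartialStateFromSchema` first runs every partial value through its key strategy (keeping nulls, via `toMapAllowingNulls`), then the old and updated entries are concatenated and collected with `toMapRemovingNulls`, so the later value wins and a `null` value deletes the key. The standalone sketch below reproduces that pipeline with plain `BinaryOperator<Object>` strategies; `SchemaMergeSketch` and its methods are hypothetical names, not library API.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.BinaryOperator;
import java.util.stream.Collector;
import java.util.stream.Stream;

final class SchemaMergeSketch {

    // Like toMapRemovingNulls: a null value removes the key; otherwise the later
    // value replaces the earlier one (last writer wins).
    static <K, V> Collector<Map.Entry<K, V>, ?, Map<K, V>> toMapRemovingNulls() {
        return Collector.of(HashMap::new, (map, entry) -> {
            if (entry.getValue() == null) {
                map.remove(entry.getKey());
            } else {
                map.put(entry.getKey(), entry.getValue());
            }
        }, (left, right) -> {
            left.putAll(right); // accumulated maps never hold null values
            return left;
        });
    }

    // Mirrors OverAllState.updateState(state, partialState, keyStrategies).
    static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState,
            Map<String, BinaryOperator<Object>> keyStrategies) {
        Objects.requireNonNull(state, "state cannot be null");
        if (partialState == null || partialState.isEmpty()) {
            return state;
        }
        // Pass 1: run each partial value through its strategy, if any (nulls are kept).
        Map<String, Object> updatedPartial = new HashMap<>();
        partialState.forEach((key, value) -> {
            BinaryOperator<Object> strategy = keyStrategies == null ? null : keyStrategies.get(key);
            updatedPartial.put(key, strategy == null ? value : strategy.apply(state.get(key), value));
        });
        // Pass 2: current entries first, updated entries second; nulls delete keys.
        return Stream.concat(state.entrySet().stream(), updatedPartial.entrySet().stream())
            .collect(toMapRemovingNulls());
    }
}
```

Reading `updateStateBySchema`, the merged map is then handed to the instance `updateState`, which only writes keys that are present and have a registered strategy. A key removed by a `null` is therefore absent from the merged map, but it is not deleted from the state's own `data`.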
RunnableConfig
Runtime configuration for a graph run
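| Field | Type | Description |
|---|---|---|
| threadId | String | Thread ID identifying the logical execution thread of a run |
| checkPointId | String | Checkpoint ID |
| nextNode | String | ID of the next node to execute |
| streamMode | CompiledGraph.StreamMode | Stream mode of the compiled graph (the builder defaults to `VALUES`); see the CompiledGraph class for details |
| metadata | `Map<String, Object>` | Custom metadata, stored as an immutable copy (`Map.copyOf`) |
| interruptedNodes | `Map<String, Object>` | Interrupt flags of nodes, keyed by formatted node ID; every new instance starts with an empty `ConcurrentHashMap`, so the flags are not copied when a config is rebuilt through the builder |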
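RunnableConfig treats its plain fields as immutable: `withStreamMode` and `withCheckPointId` return `this` when the value does not change, and otherwise rebuild the config through `builder(this)`. Because the private constructor always creates a fresh `ConcurrentHashMap` for `interruptedNodes`, interrupt flags do not survive the rebuild. The trimmed-down model below shows both effects; `SketchConfig` is a hypothetical class that keeps only two fields and omits metadata and stream mode.

```java
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, trimmed-down model of RunnableConfig (two fields only).
final class SketchConfig {
    private final String threadId;
    private final String checkPointId;
    // As in RunnableConfig's private constructor: every instance starts with a fresh map.
    private final Map<String, Object> interruptedNodes = new ConcurrentHashMap<>();

    private SketchConfig(Builder builder) {
        this.threadId = builder.threadId;
        this.checkPointId = builder.checkPointId;
    }

    Optional<String> threadId() {
        return Optional.ofNullable(threadId);
    }

    Optional<String> checkPointId() {
        return Optional.ofNullable(checkPointId);
    }

    Map<String, Object> interruptedNodes() {
        return interruptedNodes;
    }

    // Same shape as withCheckPointId: reuse this instance if nothing changes,
    // otherwise copy the fields into a builder and build a new config.
    SketchConfig withCheckPointId(String newCheckPointId) {
        if (Objects.equals(this.checkPointId, newCheckPointId)) {
            return this;
        }
        return builder(this).checkPointId(newCheckPointId).build();
    }

    static Builder builder() {
        return new Builder();
    }

    // Copies plain fields only; interruptedNodes is not part of the builder.
    static Builder builder(SketchConfig config) {
        return new Builder().threadId(config.threadId).checkPointId(config.checkPointId);
    }

    static final class Builder {
        private String threadId;
        private String checkPointId;

        Builder threadId(String threadId) {
            this.threadId = threadId;
            return this;
        }

        Builder checkPointId(String checkPointId) {
            this.checkPointId = checkPointId;
            return this;
        }

        SketchConfig build() {
            return new SketchConfig(this);
        }
    }
}
```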
package com.alibaba.cloud.ai.graph;
import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.Optional.ofNullable;
/**
* A final class representing configuration for a runnable task. This class holds various
* parameters such as thread ID, checkpoint ID, next node, and stream mode, providing
* methods to modify these parameters safely without permanently altering the original
* configuration.
*/
public final class RunnableConfig implements HasMetadata<RunnableConfig.Builder> {
private final String threadId;
private final String checkPointId;
private final String nextNode;
private final CompiledGraph.StreamMode streamMode;
private final Map<String, Object> metadata;
private final Map<String, Object> interruptedNodes;
/**
* Returns the stream mode of the compiled graph.
* @return {@code StreamMode} representing the current stream mode.
*/
public CompiledGraph.StreamMode streamMode() {
return streamMode;
}
/**
* Returns the thread ID as an {@link Optional}.
* @return the thread ID wrapped in an {@code Optional}, or an empty {@code Optional}
* if no thread ID is set.
*/
public Optional<String> threadId() {
return ofNullable(threadId);
}
/**
* Returns the current {@code checkPointId} wrapped in an {@link Optional}.
* @return an {@link Optional} containing the {@code checkPointId}, or
* {@link Optional#empty()} if it is null.
*/
public Optional<String> checkPointId() {
return ofNullable(checkPointId);
}
/**
* Returns an {@code Optional} describing the next node in the sequence, or an empty
* {@code Optional} if there is no such node.
* @return an {@code Optional} describing the next node, or an empty {@code Optional}
*/
public Optional<String> nextNode() {
return ofNullable(nextNode);
}
/**
* Checks if a node is marked as interrupted in the metadata.
* @param nodeId the ID of the node to check for interruption status
* @return true if the node is marked as interrupted, false otherwise
*/
public boolean isInterrupted(String nodeId) {
return interruptData(HasMetadata.formatNodeId(nodeId)).map(value -> Boolean.TRUE.equals(value)).orElse(false);
}
/**
* Marks a node as not interrupted by setting its value to false in the
* interrupted-nodes map.
* @param nodeId the ID of the node to mark as not interrupted
*/
public void withNodeResumed(String nodeId) {
String formattedNodeId = HasMetadata.formatNodeId(nodeId);
interruptedNodes.put(formattedNodeId, false);
}
/**
* Removes the interrupted marker for a specific node by removing its entry from the
* interrupted-nodes map.
* @param nodeId the ID of the node to remove the interrupted marker for
*/
public void removeInterrupted(String nodeId) {
String formattedNodeId = HasMetadata.formatNodeId(nodeId);
if (interruptedNodes == null || !interruptedNodes.containsKey(formattedNodeId)) {
return; // No change needed if the marker doesn't exist
}
interruptedNodes.remove(formattedNodeId);
}
/**
* Marks a node as interrupted by adding it to the interrupted-nodes map with a
* formatted key. The node ID is formatted using
* {@link HasMetadata#formatNodeId(String)} and associated with a value of
* {@code true}.
* @param nodeId the ID of the node to mark as interrupted; must not be null
* @throws NullPointerException if nodeId is null
*/
public void markNodeAsInterrupted(String nodeId) {
interruptedNodes.put(HasMetadata.formatNodeId(nodeId), true);
}
/**
* Create a new RunnableConfig with the same attributes as this one but with a
* different {@link CompiledGraph.StreamMode}.
* @param streamMode the new stream mode
* @return a new RunnableConfig with the updated stream mode
*/
public RunnableConfig withStreamMode(CompiledGraph.StreamMode streamMode) {
if (this.streamMode == streamMode) {
return this;
}
return RunnableConfig.builder(this).streamMode(streamMode).build();
}
/**
* Updates the checkpoint ID of the configuration.
* @param checkPointId The new checkpoint ID to set.
* @return A new instance of {@code RunnableConfig} with the updated checkpoint ID, or
* the current instance if no change is needed.
*/
public RunnableConfig withCheckPointId(String checkPointId) {
if (Objects.equals(this.checkPointId, checkPointId)) {
return this;
}
return RunnableConfig.builder(this).checkPointId(checkPointId).build();
}
/**
* Retrieves interrupt data associated with the specified key.
* @param key the key for which to retrieve interrupt data; may be null
* @return an Optional containing the interrupt data if present, or an empty Optional
* if the key is null or no data is found
*/
public Optional<Object> interruptData(String key) {
if (key == null) {
return Optional.empty();
}
return ofNullable(interruptedNodes).map(m -> m.get(key));
}
/**
* return metadata value for key
* @param key given metadata key
* @return metadata value for key if any
*/
@Override
public Optional<Object> metadata(String key) {
if (key == null) {
return Optional.empty();
}
return ofNullable(metadata).map(m -> m.get(key));
}
/**
* Creates a new instance of the {@link Builder} class.
* @return A new {@code Builder} object.
*/
public static Builder builder() {
return new Builder();
}
/**
* Creates a new {@code Builder} instance with the specified {@link RunnableConfig}.
* @param config The configuration for the {@code Builder}.
* @return A new {@code Builder} instance.
*/
public static Builder builder(RunnableConfig config) {
return new Builder(config);
}
/**
* A builder pattern class for constructing {@link RunnableConfig} objects. This class
* provides a fluent interface to set various properties of a {@link RunnableConfig}
* object and then build the final configuration.
*/
public static class Builder extends HasMetadata.Builder<Builder> {
private String threadId;
private String checkPointId;
private String nextNode;
private CompiledGraph.StreamMode streamMode = CompiledGraph.StreamMode.VALUES;
/**
* Constructs a new instance of the {@link Builder} with default configuration
* settings. Initializes a new {@link RunnableConfig} object for configuration
* purposes.
*/
Builder() {
}
/**
* Initializes a new instance of the {@code Builder} class with the specified
* {@link RunnableConfig}.
* @param config The configuration to be used for initialization.
*/
Builder(RunnableConfig config) {
super(requireNonNull(config, "config cannot be null!").metadata);
this.threadId = config.threadId;
this.checkPointId = config.checkPointId;
this.nextNode = config.nextNode;
this.streamMode = config.streamMode;
}
/**
* Sets the ID of the thread.
* @param threadId the ID of the thread to set
* @return a reference to this {@code Builder} object so that method calls can be
* chained together
*/
public Builder threadId(String threadId) {
this.threadId = threadId;
return this;
}
/**
* Sets the checkpoint ID for the configuration.
* @param checkPointId the ID of the checkpoint to be set
* @return a reference to this {@code Builder} instance
*/
public Builder checkPointId(String checkPointId) {
this.checkPointId = checkPointId;
return this;
}
/**
* Sets the next node in the configuration and returns this builder for method
* chaining.
* @param nextNode The next node to be set.
* @return This builder instance, allowing for method chaining.
*/
public Builder nextNode(String nextNode) {
this.nextNode = nextNode;
return this;
}
/**
* Sets the stream mode of the configuration.
* @param streamMode The {@link CompiledGraph.StreamMode} to set.
* @return A reference to this builder for method chaining.
*/
public Builder streamMode(CompiledGraph.StreamMode streamMode) {
this.streamMode = streamMode;
return this;
}
/**
* Adds a custom {@link Executor} for a specific parallel node.
* <p>
* This allows you to control the execution of branches within a parallel node.
* When a parallel node is executed, it will look for an executor in the
* {@link RunnableConfig} metadata. If found, it will be used to run the parallel
* branches concurrently.
* @param nodeId the ID of the parallel node.
* @param executor the {@link Executor} to use for the parallel node.
* @return this {@code Builder} instance for method chaining.
*/
public Builder addParallelNodeExecutor(String nodeId, Executor executor) {
return addMetadata(ParallelNode.formatNodeId(nodeId), requireNonNull(executor, "executor cannot be null!"));
}
/**
* Constructs and returns the configured {@code RunnableConfig} object.
* @return the configured {@code RunnableConfig} object
*/
public RunnableConfig build() {
return new RunnableConfig(this);
}
}
/**
* Creates a new instance of {@code RunnableConfig} from the values collected by the
* given builder.
* @param builder The configuration builder.
*/
private RunnableConfig(Builder builder) {
this.threadId = builder.threadId;
this.checkPointId = builder.checkPointId;
this.nextNode = builder.nextNode;
this.streamMode = builder.streamMode;
this.metadata = ofNullable(builder.metadata()).map(Map::copyOf).orElse(null);
this.interruptedNodes = new ConcurrentHashMap<>();
}
@Override
public String toString() {
return format("RunnableConfig{ threadId=%s, checkPointId=%s, nextNode=%s, streamMode=%s }", threadId,
checkPointId, nextNode, streamMode);
}
}
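The interrupt helpers all work on one map: `markNodeAsInterrupted` stores `true`, `withNodeResumed` stores `false`, `removeInterrupted` deletes the entry, and `isInterrupted` is true only when the stored value is `Boolean.TRUE`. Unlike `withCheckPointId`, these methods mutate the current instance and return nothing. The sketch below models that bookkeeping; `InterruptMarkers` is a hypothetical class, and its `formatNodeId` prefix is an assumption, since the real key format comes from `HasMetadata.formatNodeId`, which is not shown in this article.

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical model of RunnableConfig's interrupt bookkeeping.
final class InterruptMarkers {
    private final Map<String, Object> interruptedNodes = new ConcurrentHashMap<>();

    // Assumed key format; the real code delegates to HasMetadata.formatNodeId.
    static String formatNodeId(String nodeId) {
        return "interrupted:" + nodeId;
    }

    void markNodeAsInterrupted(String nodeId) {
        interruptedNodes.put(formatNodeId(nodeId), true);
    }

    void withNodeResumed(String nodeId) {
        interruptedNodes.put(formatNodeId(nodeId), false);
    }

    void removeInterrupted(String nodeId) {
        interruptedNodes.remove(formatNodeId(nodeId));
    }

    // A null key yields an empty Optional (ConcurrentHashMap rejects null keys).
    Optional<Object> interruptData(String key) {
        return key == null ? Optional.empty() : Optional.ofNullable(interruptedNodes.get(key));
    }

    // True only when the stored flag is exactly Boolean.TRUE.
    boolean isInterrupted(String nodeId) {
        return interruptData(formatNodeId(nodeId)).map(Boolean.TRUE::equals).orElse(false);
    }
}
```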
StateGraph
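Represents and builds state-based, graph-structured workflows. Its main features are:
- Graph structure definition: an API for building a directed graph from nodes and edges
- State management: works with OverAllState to manage state during graph execution
- Workflow orchestration: supports complex execution logic, including conditional edges and sub-graphs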
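| Field | Type | Description |
|---|---|---|
| START | String | Constant identifying the start node of the graph ("START") |
| END | String | Constant identifying the end node of the graph ("END") |
| ERROR | String | Constant identifying the error node of the graph ("ERROR") |
| NODEBEFORE | String | Constant identifying the before-node-execution hook ("NODEBEFORE") |
| NODEAFTER | String | Constant identifying the after-node-execution hook ("NODEAFTER") |
| nodes | Nodes | Container for all nodes in the graph |
| edges | Edges | Container for all edges in the graph |
| overAllStateFactory | OverAllStateFactory | Factory for creating overall state instances (deprecated) |
| keyStrategyFactory | KeyStrategyFactory | Factory that provides the key strategies |
| name | String | Name of the graph |
| stateSerializer | PlainTextStateSerializer | State serializer; constructors that take no serializer use the Jackson-based inner class `JacksonSerializer` |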
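Since the no-arg constructor assigns `HashMap::new` to `keyStrategyFactory`, `KeyStrategyFactory` has to be a functional interface that a map constructor can satisfy. The sketch below assumes a single no-argument method returning the key-to-strategy map; `SketchKeyStrategyFactory` and `KeyStrategyFactoryDemo` are hypothetical names, and strategies are modelled as `BinaryOperator<Object>` instead of the library's `KeyStrategy`.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.BinaryOperator;

// Hypothetical shape of KeyStrategyFactory: one no-argument method returning the
// key -> strategy map (strategies modelled as BinaryOperator<Object>).
@FunctionalInterface
interface SketchKeyStrategyFactory {
    Map<String, BinaryOperator<Object>> apply();
}

final class KeyStrategyFactoryDemo {
    static final BinaryOperator<Object> REPLACE = (currentValue, newValue) -> newValue;

    // Like the no-arg StateGraph constructor: an empty map, i.e. no strategies.
    static final SketchKeyStrategyFactory EMPTY = HashMap::new;

    // A custom factory: replace semantics for "input" and "answer".
    static final SketchKeyStrategyFactory CUSTOM = () -> {
        Map<String, BinaryOperator<Object>> strategies = new HashMap<>();
        strategies.put("input", REPLACE);
        strategies.put("answer", REPLACE);
        return strategies;
    };
}
```

One plausible reason for a factory rather than a shared map is that every call hands out a fresh map, so callers can register further strategies without affecting each other; this is an inference from the code, not documented behaviour.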
Public methods
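| Category | Method | Description |
|---|---|---|
| Constructor | StateGraph | Five non-deprecated constructors: no-arg; `(KeyStrategyFactory keyStrategyFactory)`: with a key strategy factory; `(String name, KeyStrategyFactory keyStrategyFactory)`: with a name and a key strategy factory; `(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer)`: with a key strategy factory and a serializer; `(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer)`: with a name, a key strategy factory and a serializer. Four further constructors taking an OverAllStateFactory are deprecated |
| Node management | addNode | `(String id, Node node)`: adds a node instance; `(String id, AsyncNodeAction action)`: adds a regular node; `(String id, AsyncNodeActionWithConfig actionWithConfig)`: adds a node whose action takes a config; `(String id, AsyncCommandAction action, Map mappings)`: adds a command node with conditional-edge mappings; `(String id, StateGraph subGraph)`: adds a sub-graph node; `(String id, CompiledGraph subGraph)`: adds a compiled sub-graph node |
| Edge management | addEdge | `(String sourceId, String targetId)`: adds a regular edge |
| | addConditionalEdges | `(String sourceId, AsyncEdgeAction condition, Map mappings)`: adds conditional edges driven by an edge action; `(String sourceId, AsyncCommandAction condition, Map mappings)`: adds conditional edges driven by a command action |
| Compilation and validation | validateGraph | Validates the graph structure |
| | compile | no-arg: compiles the graph with the default configuration; `(CompileConfig config)`: compiles the graph with the given configuration |
| Visualization | getGraph | `(GraphRepresentation.Type type, String title, boolean printConditionalEdges)`: generates a graph representation; `(GraphRepresentation.Type type, String title)`: generates a graph representation including conditional edges; `(GraphRepresentation.Type type)`: generates a graph representation using the graph name as the title |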
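Conceptually, `addConditionalEdges(sourceId, condition, mappings)` pairs a routing function with a lookup table: the condition inspects the state and returns a label, and `mappings` translates that label into the ID of the next node. The sketch below models this synchronously with `java.util.function.Function`, whereas the real `AsyncEdgeAction` is asynchronous; `ConditionalRoute` is a hypothetical class, and throwing on an unmapped label is an assumption of the sketch.

```java
import java.util.Map;
import java.util.function.Function;

// Hypothetical, synchronous model of a conditional edge: the condition maps the
// state to a label, and the mappings translate the label into the next node id.
final class ConditionalRoute {
    private final Function<Map<String, Object>, String> condition;
    private final Map<String, String> mappings;

    ConditionalRoute(Function<Map<String, Object>, String> condition, Map<String, String> mappings) {
        this.condition = condition;
        this.mappings = Map.copyOf(mappings);
    }

    String nextNode(Map<String, Object> state) {
        String label = condition.apply(state);
        String target = label == null ? null : mappings.get(label);
        if (target == null) {
            // Assumption of this sketch: an unmapped label is treated as an error.
            throw new IllegalStateException("no mapping for label: " + label);
        }
        return target;
    }
}
```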
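Static nested class Nodes manages the graph's node collection:
- anyMatchById(String id): checks whether a node with the given ID exists
- onlySubStateGraphNodes(): returns all sub-graph nodes
- exceptSubStateGraphNodes(): returns all nodes except the sub-graph nodes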
Static nested class Edges manages the graph's edge collection:
- edgeBySourceId(String sourceId): finds the edge by source node ID
- edgesByTargetId(String targetId): finds the edges by target node ID
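The lookups on `Nodes` and `Edges` are simple searches over the stored collections. The simplified model below keeps nodes as plain IDs and edges as (source, target) pairs, and adds a minimal structural check in the spirit of `validateGraph`. `SketchGraph`, its methods, and its validation rules (START needs an outgoing edge, every edge must connect known nodes, START and END are implicit, duplicate IDs are rejected) are assumptions for illustration; the real `validateGraph` may check more or different conditions.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical simplified graph: nodes are plain ids, an edge is a (source, target) pair.
final class SketchGraph {
    static final String START = "START";
    static final String END = "END";

    record SketchEdge(String sourceId, String targetId) {
    }

    private final Set<String> nodeIds = new LinkedHashSet<>();
    private final List<SketchEdge> edges = new ArrayList<>();

    SketchGraph addNode(String id) {
        if (!nodeIds.add(id)) {
            throw new IllegalArgumentException("duplicate node: " + id);
        }
        return this;
    }

    SketchGraph addEdge(String sourceId, String targetId) {
        edges.add(new SketchEdge(sourceId, targetId));
        return this;
    }

    // Counterpart of Nodes.anyMatchById.
    boolean anyMatchById(String id) {
        return nodeIds.contains(id);
    }

    // Counterpart of Edges.edgeBySourceId (here: the first edge leaving the node).
    Optional<SketchEdge> edgeBySourceId(String sourceId) {
        return edges.stream().filter(e -> e.sourceId().equals(sourceId)).findFirst();
    }

    // Counterpart of Edges.edgesByTargetId: every edge entering the node.
    List<SketchEdge> edgesByTargetId(String targetId) {
        return edges.stream().filter(e -> e.targetId().equals(targetId)).collect(Collectors.toList());
    }

    // Minimal structural check in the spirit of validateGraph (assumed rules).
    void validate() {
        if (edgeBySourceId(START).isEmpty()) {
            throw new IllegalStateException("missing edge from " + START);
        }
        for (SketchEdge e : edges) {
            if (!START.equals(e.sourceId()) && !anyMatchById(e.sourceId())) {
                throw new IllegalStateException("unknown source node: " + e.sourceId());
            }
            if (!END.equals(e.targetId()) && !anyMatchById(e.targetId())) {
                throw new IllegalStateException("unknown target node: " + e.targetId());
            }
        }
    }
}
```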
package com.alibaba.cloud.ai.graph;
import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeCondition;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.CommandNode;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import com.alibaba.cloud.ai.graph.internal.node.SubCompiledGraphNode;
import com.alibaba.cloud.ai.graph.internal.node.SubStateGraphNode;
import com.alibaba.cloud.ai.graph.serializer.StateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.PlainTextStateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.jackson.JacksonStateSerializer;
import com.alibaba.cloud.ai.graph.state.AgentStateFactory;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.LinkedHashSet;
/**
* Represents a state graph with nodes and edges.
*/
public class StateGraph {
/**
* Constant representing the END of the graph.
*/
public static final String END = "END";
/**
* Constant representing the START of the graph.
*/
public static final String START = "START";
/**
* Constant representing the ERROR of the graph.
*/
public static final String ERROR = "ERROR";
/**
* Constant representing the NODEBEFORE of the graph.
*/
public static final String NODEBEFORE = "NODEBEFORE";
/**
* Constant representing the NODEAFTER of the graph.
*/
public static final String NODEAFTER = "NODEAFTER";
/**
* Collection of nodes in the graph.
*/
final Nodes nodes = new Nodes();
/**
* Collection of edges in the graph.
*/
final Edges edges = new Edges();
/**
* Factory for creating overall state instances.
*/
private OverAllStateFactory overAllStateFactory;
/**
* Factory for providing key strategies.
*/
private KeyStrategyFactory keyStrategyFactory;
/**
* Name of the graph.
*/
private String name;
/**
* Serializer for the state.
*/
private final PlainTextStateSerializer stateSerializer;
/**
* Jackson-based serializer for state.
*/
static class JacksonSerializer extends JacksonStateSerializer {
/**
* Instantiates a new Jackson serializer.
*/
public JacksonSerializer() {
super(OverAllState::new);
}
/**
* Gets object mapper.
* @return the object mapper
*/
ObjectMapper getObjectMapper() {
return objectMapper;
}
}
/**
* Constructs a StateGraph with the specified name, key strategy factory, and state
* serializer.
* @param name the name of the graph
* @param keyStrategyFactory the factory for providing key strategies
* @param stateSerializer the plain text state serializer to use
*/
public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
this.name = name;
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
public StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
/**
* Constructs a StateGraph with the given key strategy factory and name.
* @param keyStrategyFactory the factory for providing key strategies
* @param name the name of the graph
*/
public StateGraph(String name, KeyStrategyFactory keyStrategyFactory) {
this.keyStrategyFactory = keyStrategyFactory;
this.name = name;
this.stateSerializer = new JacksonSerializer();
}
/**
* Constructs a StateGraph with the provided key strategy factory.
* @param keyStrategyFactory the factory for providing key strategies
*/
public StateGraph(KeyStrategyFactory keyStrategyFactory) {
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Deprecated constructor that initializes a StateGraph with the specified name,
* overall state factory, and state serializer.
* @param name the name of the graph
* @param overAllStateFactory the factory for creating overall state instances
* @param plainTextStateSerializer the plain text state serializer to use
*/
@Deprecated
public StateGraph(String name, OverAllStateFactory overAllStateFactory,
PlainTextStateSerializer plainTextStateSerializer) {
this.name = name;
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = plainTextStateSerializer;
}
/**
* Deprecated constructor that initializes a StateGraph with the specified name and
* overall state factory.
* @param name the name of the graph
* @param overAllStateFactory the factory for creating overall state instances
*/
@Deprecated
public StateGraph(String name, OverAllStateFactory overAllStateFactory) {
this.name = name;
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Deprecated constructor that initializes a StateGraph with the provided overall
* state factory.
* @param overAllStateFactory the factory for creating overall state instances
*/
@Deprecated
public StateGraph(OverAllStateFactory overAllStateFactory) {
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Deprecated constructor that initializes a StateGraph with the provided overall
* state factory and state serializer.
* @param overAllStateFactory the factory for creating overall state instances
* @param plainTextStateSerializer the plain text state serializer to use
*/
@Deprecated
public StateGraph(OverAllStateFactory overAllStateFactory, PlainTextStateSerializer plainTextStateSerializer) {
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = plainTextStateSerializer;
}
/**
* Default constructor that initializes a StateGraph with a Jackson-based state
* serializer and an empty key strategy map.
*/
public StateGraph() {
this.stateSerializer = new JacksonSerializer();
this.keyStrategyFactory = HashMap::new;
}
/**
* Gets the name of the graph.
* @return the name
*/
public String getName() {
return name;
}
/**
* Gets the state serializer used by this graph.
* @return the state serializer
*/
public StateSerializer<OverAllState> getStateSerializer() {
return stateSerializer;
}
/**
* Gets the state factory associated with this graph's state serializer.
* @return the state factory
*/
public final AgentStateFactory<OverAllState> getStateFactory() {
return stateSerializer.stateFactory();
}
/**
* Gets the overall state factory.
* @return the overall state factory
*/
@Deprecated
public final OverAllStateFactory getOverAllStateFactory() {
return overAllStateFactory;
}
/**
* Gets the key strategy factory.
* @return the key strategy factory
*/
public final KeyStrategyFactory getKeyStrategyFactory() {
return keyStrategyFactory;
}
/**
* Adds a commandNode to the graph.
* @param id the identifier of the node
* @param action AsyncCommandAction action
* @param mappings the mappings to be used for conditional edges
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, AsyncCommandAction action, Map<String, String> mappings)
throws GraphStateException {
return addNode(id, new CommandNode(id, action, mappings));
}
/**
* Adds a node to the graph.
* @param id the identifier of the node
* @param action the asynchronous node action to be performed by the node
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
return addNode(id, AsyncNodeActionWithConfig.of(action));
}
/**
* Adds a node to the graph with the specified action and configuration.
* @param id the identifier of the node
* @param actionWithConfig the action to be performed by the node
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
Node node = new Node(id, (config) -> actionWithConfig);
return addNode(id, node);
}
/**
* Adds a node to the graph with the specified identifier and node instance.
* @param id the identifier of the node
* @param node the node to be added
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, Node node) throws GraphStateException {
if (Objects.equals(node.id(), END)) {
throw Errors.invalidNodeIdentifier.exception(END);
}
if (!Objects.equals(node.id(), id)) {
throw Errors.invalidNodeIdentifier.exception(node.id(), id);
}
if (nodes.elements.contains(node)) {
throw Errors.duplicateNodeError.exception(id);
}
nodes.elements.add(node);
return this;
}
/**
* Adds a subgraph to the state graph by creating a node with the specified
* identifier. This implies that the subgraph shares the same state with the parent
* graph.
* @param id the identifier of the node representing the subgraph
* @param subGraph the compiled subgraph to be added
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
if (Objects.equals(id, END)) {
throw Errors.invalidNodeIdentifier.exception(END);
}
var node = new SubCompiledGraphNode(id, subGraph);
if (nodes.elements.contains(node)) {
throw Errors.duplicateNodeError.exception(id);
}
nodes.elements.add(node);
return this;
}
/**
* Adds a subgraph to the state graph by creating a node with the specified
* identifier. This implies that the subgraph will share the same state with the
* parent graph and will be compiled when the parent is compiled.
* @param id the identifier of the node representing the subgraph
* @param subGraph the subgraph to be added; it will be compiled during compilation of
* the parent
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
if (Objects.equals(id, END)) {
throw Errors.invalidNodeIdentifier.exception(END);
}
subGraph.validateGraph();
var node = new SubStateGraphNode(id, subGraph);
if (nodes.elements.contains(node)) {
throw Errors.duplicateNodeError.exception(id);
}
nodes.elements.add(node);
return this;
}
/**
* Adds an edge to the graph between the specified source and target nodes.
* @param sourceId the identifier of the source node
* @param targetId the identifier of the target node
* @return this state graph instance
* @throws GraphStateException if the edge identifier is invalid or the edge already
* exists
*/
public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
if (Objects.equals(sourceId, END)) {
throw Errors.invalidEdgeIdentifier.exception(END);
}
var newEdge = new Edge(sourceId, new EdgeValue(targetId));
int index = edges.elements.indexOf(newEdge);
if (index >= 0) {
var newTargets = new ArrayList<>(edges.elements.get(index).targets());
newTargets.add(newEdge.target());
edges.elements.set(index, new Edge(sourceId, newTargets));
}
else {
edges.elements.add(newEdge);
}
return this;
}
/**
* Adds conditional edges to the graph based on the provided condition and mappings.
* @param sourceId the identifier of the source node
* @param condition the command action used to determine the target node
* @param mappings the mappings of conditions to target nodes
* @return this state graph instance
* @throws GraphStateException if the edge identifier is invalid, the mappings are
* empty, or the edge already exists
*/
public StateGraph addConditionalEdges(String sourceId, AsyncCommandAction condition, Map<String, String> mappings)
throws GraphStateException {
if (Objects.equals(sourceId, END)) {
throw Errors.invalidEdgeIdentifier.exception(END);
}
if (mappings == null || mappings.isEmpty()) {
throw Errors.edgeMappingIsEmpty.exception(sourceId);
}
var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));
if (edges.elements.contains(newEdge)) {
throw Errors.duplicateConditionalEdgeError.exception(sourceId);
}
else {
edges.elements.add(newEdge);
}
return this;
}
/**
* Adds conditional edges to the graph based on the provided edge action condition and
* mappings.
* @param sourceId the identifier of the source node
* @param condition the edge action used to determine the target node
* @param mappings the mappings of conditions to target nodes
* @return this state graph instance
* @throws GraphStateException if the edge identifier is invalid, the mappings are
* empty, or the edge already exists
*/
public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
throws GraphStateException {
return addConditionalEdges(sourceId, AsyncCommandAction.of(condition), mappings);
}
/**
* Validates the structure of the graph ensuring all connections are valid.
* @throws GraphStateException if there are errors related to the graph state
*/
void validateGraph() throws GraphStateException {
var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);
edgeStart.validate(nodes);
validateNode(nodes);
for (Edge edge : edges.elements) {
edge.validate(nodes);
}
}
private void validateNode(Nodes nodes) throws GraphStateException {
List<CommandNode> commandNodeList = nodes.elements.stream().filter(node -> {
return node instanceof CommandNode commandNode;
}).map(node -> (CommandNode) node).toList();
for (CommandNode commandNode : commandNodeList) {
for (String key : commandNode.getMappings().keySet()) {
if (!nodes.anyMatchById(key)) {
throw Errors.missingNodeInEdgeMapping.exception(commandNode.id(), key);
}
}
}
}
/**
* Compiles the state graph into a compiled graph using the provided configuration.
* @param config the compile configuration
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph compile(CompileConfig config) throws GraphStateException {
Objects.requireNonNull(config, "config cannot be null");
validateGraph();
return new CompiledGraph(this, config);
}
/**
* Compiles the state graph into a compiled graph using a default configuration with
* memory saver.
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph compile() throws GraphStateException {
SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
return compile(CompileConfig.builder().saverConfig(saverConfig).build());
}
/**
* Generates a drawable graph representation of the state graph.
* @param type the type of graph representation to generate
* @param title the title of the graph
* @param printConditionalEdges whether to include conditional edges in the output
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
String content = type.generator.generate(nodes, edges, title, printConditionalEdges);
return new GraphRepresentation(type, content);
}
/**
* Generates a drawable graph representation of the state graph with conditional edges
* included.
* @param type the type of graph representation to generate
* @param title the title of the graph
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
String content = type.generator.generate(nodes, edges, title, true);
return new GraphRepresentation(type, content);
}
/**
* Generates a drawable graph representation of the state graph using the graph's name
* as title.
* @param type the type of graph representation to generate
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type) {
String content = type.generator.generate(nodes, edges, name, true);
return new GraphRepresentation(type, content);
}
/**
* Container for nodes in the graph.
*/
public static class Nodes {
/**
* The collection of nodes.
*/
public final Set<Node> elements;
/**
* Instantiates a new Nodes container with the provided elements.
* @param elements the elements to initialize
*/
public Nodes(Collection<Node> elements) {
this.elements = new LinkedHashSet<>(elements);
}
/**
* Instantiates a new empty Nodes container.
*/
public Nodes() {
this.elements = new LinkedHashSet<>();
}
/**
* Checks if any node matches the given identifier.
* @param id the identifier to match
* @return true if a matching node is found, false otherwise
*/
public boolean anyMatchById(String id) {
return elements.stream().anyMatch(n -> Objects.equals(n.id(), id));
}
/**
* Returns a list of sub-state graph nodes.
* @return a list of sub-state graph nodes
*/
public List<SubStateGraphNode> onlySubStateGraphNodes() {
return elements.stream()
.filter(n -> n instanceof SubStateGraphNode)
.map(n -> (SubStateGraphNode) n)
.toList();
}
/**
* Returns a list of nodes excluding sub-state graph nodes.
* @return a list of nodes excluding sub-state graph nodes
*/
public List<Node> exceptSubStateGraphNodes() {
return elements.stream().filter(n -> !(n instanceof SubStateGraphNode)).toList();
}
}
/**
* Container for edges in the graph.
*/
public static class Edges {
/**
* The collection of edges.
*/
public final List<Edge> elements;
/**
* Instantiates a new Edges container with the provided elements.
* @param elements the elements to initialize
*/
public Edges(Collection<Edge> elements) {
this.elements = new LinkedList<>(elements);
}
/**
* Instantiates a new empty Edges container.
*/
public Edges() {
this.elements = new LinkedList<>();
}
/**
* Retrieves the first edge matching the specified source identifier.
* @param sourceId the source identifier to match
* @return an optional containing the matched edge, or empty if none found
*/
public Optional<Edge> edgeBySourceId(String sourceId) {
return elements.stream().filter(e -> Objects.equals(e.sourceId(), sourceId)).findFirst();
}
/**
* Retrieves a list of edges targeting the specified node identifier.
* @param targetId the target identifier to match
* @return a list of edges targeting the specified identifier
*/
public List<Edge> edgesByTargetId(String targetId) {
return elements.stream().filter(e -> e.anyMatchByTargetId(targetId)).toList();
}
}
}
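The model below makes the registration and validation rules above concrete. It is a minimal, self-contained sketch that uses only JDK collections. `StateGraphRulesDemo`, `SimpleEdge` and the `start`/`end` ids are hypothetical stand-ins, not the library's classes or constants. It reproduces two behaviors visible in the source:

- Because `Edge.equals` compares only `sourceId`, a second `addEdge` from the same source is merged into one parallel edge.
- `validateGraph` requires an edge leaving START and checks that every referenced node exists. The sketch collects every problem, whereas the real method throws on the first one.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;

public class StateGraphRulesDemo {

    // Placeholder ids for the START/END pseudo-nodes (hypothetical values)
    static final String START = "start";
    static final String END = "end";

    // Stand-in for Edge: like Edge.equals, equality looks at sourceId only
    record SimpleEdge(String sourceId, List<String> targets) {
        @Override
        public boolean equals(Object o) {
            return o instanceof SimpleEdge e && Objects.equals(sourceId, e.sourceId);
        }

        @Override
        public int hashCode() {
            return Objects.hash(sourceId);
        }
    }

    final Set<String> nodes = new LinkedHashSet<>();
    final List<SimpleEdge> edges = new LinkedList<>();

    // Like addNode: a second node with the same id is rejected
    StateGraphRulesDemo addNode(String id) {
        if (!nodes.add(id)) {
            throw new IllegalStateException("duplicate node: " + id);
        }
        return this;
    }

    // Like addEdge: indexOf finds any edge with the same source, and its targets are extended
    StateGraphRulesDemo addEdge(String sourceId, String targetId) {
        var newEdge = new SimpleEdge(sourceId, List.of(targetId));
        int index = edges.indexOf(newEdge);
        if (index >= 0) {
            var newTargets = new ArrayList<>(edges.get(index).targets());
            newTargets.add(targetId);
            edges.set(index, new SimpleEdge(sourceId, newTargets));
        } else {
            edges.add(newEdge);
        }
        return this;
    }

    // Like validateGraph, but collects every problem instead of throwing on the first one
    List<String> validate() {
        List<String> problems = new ArrayList<>();
        if (edges.stream().noneMatch(e -> e.sourceId().equals(START))) {
            problems.add("missing entry point");
        }
        for (SimpleEdge e : edges) {
            if (!e.sourceId().equals(START) && !nodes.contains(e.sourceId())) {
                problems.add("missing source node: " + e.sourceId());
            }
            for (String target : e.targets()) {
                if (!target.equals(END) && !nodes.contains(target)) {
                    problems.add("missing target node: " + target);
                }
            }
        }
        return problems;
    }

    public static void main(String[] args) {
        var g = new StateGraphRulesDemo().addNode("a").addNode("b")
            .addEdge(START, "a").addEdge("a", "b").addEdge("a", "c").addEdge("b", END);
        System.out.println(g.edges.get(1).targets()); // [b, c]: "a" now fans out in parallel
        System.out.println(g.validate());             // [missing target node: c]
    }
}
```

The same sourceId-only equality also explains another check. `addConditionalEdges` raises `duplicateConditionalEdgeError` whenever the source node already has any edge, whether fixed or conditional.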
CompileConfig
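Configures the graph compilation process, including checkpoint savers, interrupt points, and lifecycle listeners.
| Field | Type | Description |
| saverConfig | SaverConfig | Saver configuration that manages checkpoint savers; replaces the old checkpointSaver field. Defaults to a MemorySaver registered under MEMORY |
| lifecycleListeners | Deque<GraphLifecycleListener> | Queue of graph lifecycle listeners that observe node execution events (a LinkedBlockingDeque bounded to 25 entries) |
| interruptsBefore | Set<String> | IDs of the nodes before which execution is interrupted |
| interruptsAfter | Set<String> | IDs of the nodes after which execution is interrupted |
| releaseThread | boolean | Whether to release the thread during execution (see BaseCheckpointSaver#release(RunnableConfig)) |
| observationRegistry | ObservationRegistry | Observation registry for monitoring and tracing; defaults to ObservationRegistry.NOOP |
Exposed methods
| Method | Description |
| releaseThread | Returns the current value of the thread-release flag |
| lifecycleListeners | Returns the lifecycle listener queue (the Javadoc calls it unmodifiable, but the internal deque is returned as-is) |
| observationRegistry | Returns the observation registry used for monitoring and tracing |
| interruptsBefore | Returns the set of nodes before which execution is interrupted |
| interruptsAfter | Returns the set of nodes after which execution is interrupted |
| checkpointSaver | Retrieves the default checkpoint saver from the saver configuration, or the saver of a given type when a type is passed |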
package com.alibaba.cloud.ai.graph;
import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import io.micrometer.observation.ObservationRegistry;
import java.util.Collection;
import java.util.Deque;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;
import static com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant.MEMORY;
import static java.util.Optional.ofNullable;
/**
* This class is a configuration container for defining compile settings and behaviors.
* It includes various fields and methods to manage checkpoint savers and interrupts,
* providing both deprecated and current accessors.
*/
public class CompileConfig {
private SaverConfig saverConfig = new SaverConfig().register(MEMORY, new MemorySaver());
private Deque<GraphLifecycleListener> lifecycleListeners = new LinkedBlockingDeque<>(25);
// private BaseCheckpointSaver checkpointSaver; // replaced with SaverConfig
private Set<String> interruptsBefore = Set.of();
private Set<String> interruptsAfter = Set.of();
private boolean releaseThread = false;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
/**
* Returns the current state of the thread release flag.
*
* @see BaseCheckpointSaver#release(RunnableConfig)
* @return true if the thread has been released, false otherwise
*/
public boolean releaseThread() {
return releaseThread;
}
/**
* Gets an unmodifiable list of node lifecycle listeners.
* @return The list of lifecycle listeners.
*/
public Queue<GraphLifecycleListener> lifecycleListeners() {
return lifecycleListeners;
}
/**
* Gets observation registry for monitoring and tracing.
* @return The observation registry instance.
*/
public ObservationRegistry observationRegistry() {
return observationRegistry;
}
/**
* Returns the array of interrupts that will occur before the specified node
* (deprecated).
* @return An array of interruptible nodes.
* @deprecated Use {@link #interruptsBefore()} instead for better immutability and
* type safety.
*/
@Deprecated
public String[] getInterruptBefore() {
return interruptsBefore.toArray(new String[0]);
}
/**
* Returns the array of interrupts that will occur after the specified node
* (deprecated).
* @return An array of interruptible nodes.
* @deprecated Use {@link #interruptsAfter()} instead for better immutability and type
* safety.
*/
@Deprecated
public String[] getInterruptAfter() {
return interruptsAfter.toArray(new String[0]);
}
/**
* Returns the set of interrupts that will occur before the specified node.
* @return An unmodifiable set of interruptible nodes.
*/
public Set<String> interruptsBefore() {
return interruptsBefore;
}
/**
* Returns the set of interrupts that will occur after the specified node.
* @return An unmodifiable set of interruptible nodes.
*/
public Set<String> interruptsAfter() {
return interruptsAfter;
}
/**
* Retrieves a checkpoint saver based on the specified type from the saver
* configuration.
* @param type The type of the checkpoint saver to retrieve.
* @return An Optional containing the checkpoint saver if available; otherwise, empty.
*/
public Optional<BaseCheckpointSaver> checkpointSaver(String type) {
return ofNullable(saverConfig.get(type));
}
/**
* Retrieves the default checkpoint saver from the saver configuration.
* @return An Optional containing the default checkpoint saver if available;
* otherwise, empty.
*/
public Optional<BaseCheckpointSaver> checkpointSaver() {
return ofNullable(saverConfig.get());
}
/**
* Returns a new instance of the builder with default configuration settings.
* @return A new Builder instance.
*/
public static Builder builder() {
return new Builder(new CompileConfig());
}
/**
* Returns a new instance of the builder initialized with the provided configuration.
* @param config The compile configuration to use as a base.
* @return A new Builder instance initialized with the given configuration.
*/
public static Builder builder(CompileConfig config) {
return new Builder(config);
}
/**
* Builder class for creating instances of CompileConfig. It allows setting various
* options such as savers, interrupts, and lifecycle listeners in a fluent manner.
*/
public static class Builder {
private final CompileConfig config;
/**
* Initializes the builder with the provided compile configuration.
* @param config The base configuration to start from.
*/
protected Builder(CompileConfig config) {
this.config = new CompileConfig(config);
}
/**
* Sets whether the thread should be released during execution.
* @param releaseThread Flag indicating whether to release the thread.
* @see BaseCheckpointSaver#release(RunnableConfig)
* @return This builder instance for method chaining.
*/
public Builder releaseThread(boolean releaseThread) {
this.config.releaseThread = releaseThread;
return this;
}
/**
* Sets the observation registry for monitoring and tracing.
* @param observationRegistry The ObservationRegistry to use.
* @return This builder instance for method chaining.
*/
public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.config.observationRegistry = observationRegistry;
return this;
}
/**
* Sets the saver configuration for checkpoints.
* @param saverConfig The SaverConfig to use.
* @return This builder instance for method chaining.
*/
public Builder saverConfig(SaverConfig saverConfig) {
this.config.saverConfig = saverConfig;
return this;
}
/**
* Sets individual interrupt points that trigger before node execution using
* varargs.
* @param interruptBefore One or more strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptBefore(String... interruptBefore) {
this.config.interruptsBefore = Set.of(interruptBefore);
return this;
}
/**
* Sets individual interrupt points that trigger after node execution using
* varargs.
* @param interruptAfter One or more strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptAfter(String... interruptAfter) {
this.config.interruptsAfter = Set.of(interruptAfter);
return this;
}
/**
* Sets multiple interrupt points that trigger before node execution from a
* collection.
* @param interruptsBefore Collection of strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptsBefore(Collection<String> interruptsBefore) {
this.config.interruptsBefore = interruptsBefore.stream().collect(Collectors.toUnmodifiableSet());
return this;
}
/**
* Sets multiple interrupt points that trigger after node execution from a
* collection.
* @param interruptsAfter Collection of strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptsAfter(Collection<String> interruptsAfter) {
this.config.interruptsAfter = interruptsAfter.stream().collect(Collectors.toUnmodifiableSet());
return this;
}
/**
* Adds a lifecycle listener to monitor node execution events.
* @param listener The GraphLifecycleListener to add.
* @return This builder instance for method chaining.
*/
public Builder withLifecycleListener(GraphLifecycleListener listener) {
this.config.lifecycleListeners.offer(listener);
return this;
}
/**
* Finalizes the configuration and returns the compiled instance.
* @return The configured CompileConfig object.
*/
public CompileConfig build() {
return config;
}
}
/**
* Default constructor used internally to create a new configuration with default
* settings. Made private to ensure all instances are created through the builder
* pattern.
*/
private CompileConfig() {
}
/**
* Copy constructor to create a new instance based on an existing configuration.
* @param config The configuration to copy.
*/
private CompileConfig(CompileConfig config) {
this.saverConfig = config.saverConfig;
this.interruptsBefore = config.interruptsBefore;
this.interruptsAfter = config.interruptsAfter;
this.releaseThread = config.releaseThread;
this.lifecycleListeners = config.lifecycleListeners;
this.observationRegistry = config.observationRegistry;
}
}
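Reading the builder code above alongside standard JDK semantics suggests a few caveats. The demo below shows them using JDK classes only; `CompileConfigCaveats` and its methods are hypothetical helpers.

- `interruptBefore(String...)` goes through `Set.of`, which throws `IllegalArgumentException` on duplicate ids.
- The `Collection` overloads de-duplicate silently via `Collectors.toUnmodifiableSet()`.
- `withLifecycleListener` calls `offer` on a `LinkedBlockingDeque` of capacity 25. Once 25 listeners are registered, `offer` returns `false`, and the builder ignores that return value.

```java
import java.util.Deque;
import java.util.List;
import java.util.Set;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;

public class CompileConfigCaveats {

    // Varargs path (interruptBefore/interruptAfter): Set.of rejects duplicate elements
    static boolean varargsRejectsDuplicates(String... ids) {
        try {
            Set.of(ids);
            return false;
        } catch (IllegalArgumentException e) {
            return true;
        }
    }

    // Collection path (interruptsBefore/interruptsAfter): duplicates are silently merged
    static Set<String> fromCollection(List<String> ids) {
        return ids.stream().collect(Collectors.toUnmodifiableSet());
    }

    // Listener queue: same type and capacity as the lifecycleListeners field;
    // offer() returns false instead of throwing once the deque is full
    static int acceptedListeners(int attempts) {
        Deque<Object> listeners = new LinkedBlockingDeque<>(25);
        int accepted = 0;
        for (int i = 0; i < attempts; i++) {
            if (listeners.offer(new Object())) {
                accepted++;
            }
        }
        return accepted;
    }

    public static void main(String[] args) {
        System.out.println(varargsRejectsDuplicates("review", "review")); // true
        System.out.println(fromCollection(List.of("review", "review")));  // [review]
        System.out.println(acceptedListeners(30));                         // 25
    }
}
```

Two further observations come from reading the code rather than running the library. First, repeated `interruptBefore`/`interruptAfter` calls replace the previous set instead of adding to it. Second, the private copy constructor copies the `lifecycleListeners` reference, so `CompileConfig.builder(existing).withLifecycleListener(...)` appears to add the listener to `existing` as well.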
CompiledGraph
The core component of the graph framework: it represents a compiled graph structure.
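| Field | Type | Description |
| stateGraph | StateGraph | The original state graph this compiled graph was built from |
| keyStrategyMap | Map | Key update strategies (key → strategy) |
| nodes | Map | Compiled node map |
| edges | Map | Compiled edge map |
| processedData | ProcessedNodesEdgesAndConfig | The processed nodes, edges, and configuration |
| maxIterations | int | Maximum number of iterations, 25 by default |
| compileConfig | CompileConfig | Compilation configuration |
| INTERRUPT_AFTER | String | Static string constant, "__INTERRUPTED__" by default |
Exposed methods
| Category | Method | Description |
| Graph execution | stream | - no arguments: creates a default streaming output - (Map inputs): streams from the given inputs - (Map inputs, RunnableConfig config): streams from the given inputs and configuration |
| | streamFromInitialNode(OverAllState overAllState, RunnableConfig config) | Streams execution starting from the initial node |
| | invoke | - (Map inputs): runs the graph with the given inputs - (OverAllState overAllState, RunnableConfig config): runs the graph from the given initial state - (Map inputs, RunnableConfig config): runs the graph and returns the final state |
| | updateState(RunnableConfig config, Map values, String asNode) | Updates the graph state |
| | getStateHistory(RunnableConfig config) | Returns the state history |
| | getState(RunnableConfig config) | Returns the state snapshot for the given configuration |
| | stateOf(RunnableConfig config) | Returns a state snapshot |
| | setMaxIterations(int maxIterations) | Sets the maximum number of iterations |
| Graph representation | getGraph | - (GraphRepresentation.Type type): generates the default graph representation - (GraphRepresentation.Type type, String title): generates a representation with a title - (GraphRepresentation.Type type, String title, boolean printConditionalEdges): generates a representation of the given type, optionally including conditional edges |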
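StreamMode enum
- VALUES: value streaming mode
- SNAPSHOTS: snapshot streaming mode
AsyncNodeGenerator inner class: drives the asynchronous execution of the graph and handles its streaming output
Core fields:
- Cursor cursor: a cursor object that tracks the current and next node IDs
  - String currentNodeId: ID of the current node
  - String nextNodeId: ID of the next node
  - String resumeFrom: ID of the node from which execution resumes
- int iteration: current iteration count
- RunnableConfig config: runtime configuration
- boolean returnFromEmbed: whether execution is returning from an embedded generator
- Map<String, Object> currentState: the current state as a Map
- OverAllState overAllState: the wrapped OverAllState object, which offers richer state operations
| Core method | Description |
| next | Executes the next step of the graph |
| evaluateAction | Evaluates and runs a node's action |
| nextNodeId | Determines the next node from the current node and state |
| getEmbedGenerator | Extracts an embedded generator from a partial state |
| processGeneratorOutput | Processes the data output by a generator |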
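The synchronous model below illustrates how `next`, `nextNodeId` and `maxIterations` work together. It is a simplified sketch, not the real `AsyncNodeGenerator`. It makes these assumptions:

- Every key uses replace semantics (like `ReplaceStrategy`).
- Nodes and routes are plain functions.
- The `start`/`end` ids are hypothetical.
- Interrupts, checkpoints, streaming and embedded generators are ignored.

The sketch throws an exception once the iteration count passes `maxIterations`. How the real engine reports that is not shown in this section.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class GraphRunSketch {

    // Placeholder pseudo-node ids (hypothetical values)
    static final String START = "start";
    static final String END = "end";

    record Result(List<String> visited, Map<String, Object> state) {}

    static Result run(Map<String, Function<Map<String, Object>, Map<String, Object>>> nodes,
            Map<String, Function<Map<String, Object>, String>> routes, Map<String, Object> input,
            int maxIterations) {
        Map<String, Object> state = new HashMap<>(input);
        List<String> visited = new ArrayList<>();
        String current = routes.get(START).apply(state); // first real node after START
        int iteration = 0;
        while (!END.equals(current)) {
            if (++iteration > maxIterations) {
                throw new IllegalStateException("exceeded maxIterations = " + maxIterations);
            }
            visited.add(current);
            // Replace-style merge: each key returned by the node overwrites the previous value
            state.putAll(nodes.get(current).apply(state));
            current = routes.get(current).apply(state); // pick the next node from the updated state
        }
        return new Result(visited, state);
    }

    // A loop: "inc" runs until count reaches 3, then the route goes to END
    static Result counterDemo(int maxIterations) {
        Function<Map<String, Object>, Map<String, Object>> inc =
            s -> Map.of("count", (Integer) s.get("count") + 1);
        Function<Map<String, Object>, String> fromStart = s -> "inc";
        Function<Map<String, Object>, String> fromInc = s -> (Integer) s.get("count") < 3 ? "inc" : END;
        return run(Map.of("inc", inc), Map.of(START, fromStart, "inc", fromInc), Map.of("count", 0),
            maxIterations);
    }

    public static void main(String[] args) {
        System.out.println(counterDemo(10).visited()); // [inc, inc, inc]
        try {
            counterDemo(2);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // exceeded maxIterations = 2
        }
    }
}
```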
Node
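The core node class of the graph structure. It defines a single node in the graph, consisting of the node's identifier and an optional action factory.
| Field | Type | Description |
| id | String | Unique identifier of the node |
| actionFactory | ActionFactory | Creates the action the node executes (from a CompileConfig) |
Exposed methods
| Method | Description |
| isParallel | Checks whether the node is a parallel node; the current implementation always returns false |
| withIdUpdated | Returns a new node instance whose ID has been transformed by the given function |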
package com.alibaba.cloud.ai.graph.internal.node;
import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import java.util.Objects;
import java.util.function.Function;
import static java.lang.String.format;
/**
* Represents a node in a graph, characterized by a unique identifier and a factory for
* creating actions to be executed by the node. The actions operate on
* {@link OverAllState}.
*/
public class Node {
public interface ActionFactory {
AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
}
private final String id;
private final ActionFactory actionFactory;
public Node(String id, ActionFactory actionFactory) {
this.id = id;
this.actionFactory = actionFactory;
}
/**
* Constructor that accepts only the `id` and sets `actionFactory` to null.
* @param id the unique identifier for the node
*/
public Node(String id) {
this(id, null);
}
/**
* id
* @return the unique identifier for the node.
*/
public String id() {
return id;
}
/**
* actionFactory
* @return a factory function that takes a {@link CompileConfig} and returns an
* {@link AsyncNodeActionWithConfig} instance for the specified {@code State}.
*/
public ActionFactory actionFactory() {
return actionFactory;
}
public boolean isParallel() {
// return id.startsWith(PARALLEL_PREFIX);
return false;
}
public Node withIdUpdated(Function<String, String> newId) {
return new Node(newId.apply(id), actionFactory);
}
/**
* Checks if this node is equal to another object.
* @param o the object to compare with
* @return true if this node is equal to the specified object, false otherwise
*/
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null)
return false;
if (o instanceof Node node) {
return Objects.equals(id, node.id);
}
return false;
}
/**
* Returns the hash code value for this node.
* @return the hash code value for this node
*/
@Override
public int hashCode() {
return Objects.hash(id);
}
@Override
public String toString() {
return format("Node(%s,%s)", id, actionFactory != null ? "action" : "null");
}
}
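`Node.equals` and `hashCode` use only `id`. This is what lets `StateGraph.addNode` detect duplicates with `nodes.elements.contains(node)`. The sketch below mirrors that behavior with a hypothetical `SimpleNode`, where the action is reduced to a string label. It also shows `withIdUpdated` keeping the action while rewriting the id. The `sub-` prefix is only an illustrative renaming rule.

```java
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

public class NodeIdentityDemo {

    // Hypothetical stand-in for Node: the action factory is modeled as a plain label
    static final class SimpleNode {
        final String id;
        final String action;

        SimpleNode(String id, String action) {
            this.id = id;
            this.action = action;
        }

        // Same idea as Node.withIdUpdated: keep the action, transform the id
        SimpleNode withIdUpdated(Function<String, String> newId) {
            return new SimpleNode(newId.apply(id), action);
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof SimpleNode n && Objects.equals(id, n.id); // action is ignored
        }

        @Override
        public int hashCode() {
            return Objects.hash(id);
        }
    }

    public static void main(String[] args) {
        Set<SimpleNode> nodes = new LinkedHashSet<>();
        System.out.println(nodes.add(new SimpleNode("llm", "callModel")));     // true
        System.out.println(nodes.add(new SimpleNode("llm", "somethingElse"))); // false: same id
        SimpleNode renamed = new SimpleNode("llm", "callModel").withIdUpdated(id -> "sub-" + id);
        System.out.println(renamed.id + " " + renamed.action);                 // sub-llm callModel
    }
}
```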
Edge
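Defines the connection between nodes; implemented as a record.
| Field | Type | Description |
| sourceId | String | ID of the source node, i.e. the node where the edge starts |
| targets | List<EdgeValue> | The edge's target nodes or conditions |
| Category | Method | Description |
| Construction | Edge | - (String id): creates an edge with only a source node and no targets - (String sourceId, EdgeValue target): creates an edge from the source node to a single target - (String sourceId, List targets): creates an edge from the source node to multiple targets |
| Edge checks | isParallel | Whether the edge is parallel (more than one target) |
| | target | Returns the single target; throws IllegalStateException if the edge is parallel |
| | anyMatchByTargetId | Checks whether any target (including conditional mapping values) matches the given node ID |
| | withSourceAndTargetIdsUpdated | Returns a new Edge instance with the source ID and target values updated |
| | validate | Validates the edge |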
package com.alibaba.cloud.ai.graph.internal.edge;
import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static java.lang.String.format;
/**
* Represents an edge in a graph with a source ID and a target value.
*
* @param sourceId The ID of the source node.
* @param targets The targets value associated with the edge.
*/
public record Edge(String sourceId, List<EdgeValue> targets) {
public Edge(String sourceId, EdgeValue target) {
this(sourceId, List.of(target));
}
public Edge(String id) {
this(id, List.of());
}
public boolean isParallel() {
return targets.size() > 1;
}
public EdgeValue target() {
if (isParallel()) {
throw new IllegalStateException(format("Edge '%s' is parallel", sourceId));
}
return targets.get(0);
}
public boolean anyMatchByTargetId(String targetId) {
return targets().stream()
.anyMatch(v -> (v.id() != null) ? Objects.equals(v.id(), targetId)
: v.value().mappings().containsValue(targetId)
);
}
public Edge withSourceAndTargetIdsUpdated(Node node, Function<String, String> newSourceId,
Function<String, EdgeValue> newTarget) {
var newTargets = targets().stream().map(t -> t.withTargetIdsUpdated(newTarget)).toList();
return new Edge(newSourceId.apply(sourceId), newTargets);
}
public void validate(StateGraph.Nodes nodes) throws GraphStateException {
if (!Objects.equals(sourceId(), START) && !nodes.anyMatchById(sourceId())) {
throw Errors.missingNodeReferencedByEdge.exception(sourceId());
}
if (isParallel()) { // check for duplicate targets
Set<String> duplicates = targets.stream()
.collect(Collectors.groupingBy(EdgeValue::id, Collectors.counting())) // group by target id and count occurrences
.entrySet()
.stream()
.filter(entry -> entry.getValue() > 1) // keep ids that occur more than once
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
if (!duplicates.isEmpty()) {
throw Errors.duplicateEdgeTargetError.exception(sourceId(), duplicates);
}
}
for (EdgeValue target : targets) {
validate(target, nodes);
}
}
private void validate(EdgeValue target, StateGraph.Nodes nodes) throws GraphStateException {
if (target.id() != null) {
if (!Objects.equals(target.id(), StateGraph.END) && !nodes.anyMatchById(target.id())) {
throw Errors.missingNodeReferencedByEdge.exception(target.id());
}
}
else if (target.value() != null) {
for (String nodeId : target.value().mappings().values()) {
if (!Objects.equals(nodeId, StateGraph.END) && !nodes.anyMatchById(nodeId)) {
throw Errors.missingNodeInEdgeMapping.exception(sourceId(), nodeId);
}
}
}
else {
throw Errors.invalidEdgeTarget.exception(sourceId());
}
}
/**
* Checks if this edge is equal to another object.
* @param o the object to compare with
* @return true if this edge is equal to the specified object, false otherwise
*/
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
Edge node = (Edge) o;
return Objects.equals(sourceId, node.sourceId);
}
/**
* Returns the hash code value for this edge.
* @return the hash code value for this edge
*/
@Override
public int hashCode() {
return Objects.hash(sourceId);
}
}
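The duplicate-target check in `Edge.validate` is a plain group-and-count pipeline. The standalone version below applies it to target ids directly; `DuplicateTargetsDemo` is a hypothetical name. It also shows something worth checking in the real code:

- `Collectors.groupingBy` throws `NullPointerException` for a null key.
- A conditional `EdgeValue` has `id == null`.
- So a parallel edge that mixes a conditional target with a fixed one might fail with an NPE instead of a `GraphStateException`. Calling `addConditionalEdges` and then `addEdge` on the same source would appear to produce such an edge.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class DuplicateTargetsDemo {

    // Same pipeline as Edge.validate: count each target id and keep those seen more than once
    static Set<String> duplicates(List<String> targetIds) {
        Map<String, Long> counts = targetIds.stream()
            .collect(Collectors.groupingBy(id -> id, Collectors.counting()));
        return counts.entrySet()
            .stream()
            .filter(entry -> entry.getValue() > 1)
            .map(Map.Entry::getKey)
            .collect(Collectors.toSet());
    }

    // A conditional EdgeValue has a null id; groupingBy refuses to map an element to a null key
    static boolean failsOnNullId() {
        try {
            duplicates(Arrays.asList("b", null));
            return false;
        } catch (NullPointerException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(duplicates(List.of("b", "c")));      // []
        System.out.println(duplicates(List.of("b", "c", "b"))); // [b]
        System.out.println(failsOnNullId());                    // true
    }
}
```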
EdgeValue
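The core class for edge values. It defines an edge's target node or condition, which is how conditional edges are supported.
| Field | Type | Description |
| id | String | Unique identifier of the target node, used by fixed edges |
| value | EdgeCondition | Condition associated with the edge value, used by conditional edges |
| Category | Method | Description |
| Construction | EdgeValue | - (String id): creates an EdgeValue with only a target node ID; the condition is null - (EdgeCondition value): creates an EdgeValue with only a condition; the target node ID is null - (String id, EdgeCondition value): creates an EdgeValue with both a target node ID and a condition |
| Update | withTargetIdsUpdated(Function target) | Updates the target node IDs and returns a new EdgeValue instance - if the EdgeValue has a target node ID, the target function is applied to it - if the EdgeValue has a condition, the node IDs in the condition's mappings are updated |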
package com.alibaba.cloud.ai.graph.internal.edge;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* @param id The unique identifier for the edge value.
* @param value The condition associated with the edge value.
*/
public record EdgeValue(String id, EdgeCondition value) {
public EdgeValue(String id) {
this(id, null);
}
public EdgeValue(EdgeCondition value) {
this(null, value);
}
EdgeValue withTargetIdsUpdated(Function<String, EdgeValue> target) {
if (id != null) {
return target.apply(id);
}
var newMappings = value.mappings().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> {
var v = target.apply(e.getValue());
return (v.id() != null) ? v.id() : e.getValue();
}));
return new EdgeValue(null, new EdgeCondition(value.action(), newMappings));
}
}
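`withTargetIdsUpdated` lets node ids on edges be rewritten after the fact, for example when ids are renamed. The `sub-` prefix below is a purely illustrative rule. The sketch copies the method's logic into a hypothetical `Value` record so it can run without the library:

- A fixed edge hands its id to the function.
- A conditional edge rewrites each mapping value, keeping the old id whenever the function does not return a fixed id.

```java
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class RetargetDemo {

    // Hypothetical mini EdgeValue: a fixed target id OR a mapping table (routing action omitted)
    record Value(String id, Map<String, String> mappings) {

        static Value fixed(String id) {
            return new Value(id, null);
        }

        static Value conditional(Map<String, String> mappings) {
            return new Value(null, mappings);
        }

        // Same logic as EdgeValue.withTargetIdsUpdated
        Value withTargetIdsUpdated(Function<String, Value> target) {
            if (id != null) {
                return target.apply(id); // fixed edge: the function decides the new value
            }
            // Conditional edge: rewrite every mapped node id, keeping the old id when the
            // function does not produce a fixed id
            Map<String, String> newMappings = mappings.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> {
                    Value v = target.apply(e.getValue());
                    return (v.id() != null) ? v.id() : e.getValue();
                }));
            return conditional(newMappings);
        }
    }

    public static void main(String[] args) {
        Function<String, Value> prefix = id -> Value.fixed("sub-" + id); // illustrative rule
        System.out.println(Value.fixed("b").withTargetIdsUpdated(prefix).id()); // sub-b
        System.out.println(
            Value.conditional(Map.of("yes", "a")).withTargetIdsUpdated(prefix).mappings()); // {yes=sub-a}
    }
}
```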
EdgeCondition
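Defines the condition and the mappings of a conditional edge.
| Field | Type | Description |
| action | AsyncCommandAction | Asynchronous command action that runs the condition logic |
| mappings | Map<String, String> | Maps condition results to target node IDs |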
package com.alibaba.cloud.ai.graph.internal.edge;
import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import java.util.Map;
import static java.lang.String.format;
/**
* Represents a condition associated with an edge in a graph.
*
* @param action The action to be performed asynchronously when the edge condition is met.
* @param mappings A map of string key-value pairs representing additional mappings for
* the edge condition.
*/
public record EdgeCondition(AsyncCommandAction action, Map<String, String> mappings) {
@Override
public String toString() {
return format("EdgeCondition[ %s, mapping=%s ]", action != null ? "action" : "null", mappings);
}
}
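At run time, a conditional edge comes down to three steps: run the action, take the key it yields, and look that key up in `mappings`. The sketch below models this with a `CompletableFuture<String>` route key. That is an assumption for illustration, since the return type of `AsyncCommandAction` is not shown in this section. `Condition`, `reviewRouter` and the node ids are hypothetical. How the real engine reacts to an unmapped key is also not shown here; the sketch simply throws.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

public class ConditionRoutingDemo {

    // Hypothetical mini EdgeCondition: an async action yielding a route key, plus key -> node id mappings
    record Condition(Function<Map<String, Object>, CompletableFuture<String>> action,
            Map<String, String> mappings) {

        String resolve(Map<String, Object> state) {
            String key = action.apply(state).join(); // wait for the routing decision
            String next = mappings.get(key);
            if (next == null) {
                throw new IllegalStateException("no mapping for route key '" + key + "'");
            }
            return next;
        }
    }

    static Condition reviewRouter() {
        return new Condition(
            state -> CompletableFuture.completedFuture((Boolean) state.get("approved") ? "yes" : "no"),
            Map.of("yes", "publish", "no", "revise"));
    }

    public static void main(String[] args) {
        System.out.println(reviewRouter().resolve(Map.of("approved", true)));  // publish
        System.out.println(reviewRouter().resolve(Map.of("approved", false))); // revise
    }
}
```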
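Study group
Hi, I'm 影子. I have worked at 🐻, a new-energy company, and 老铁, and I am also a Committer in the Spring AI Alibaba open-source community. I have recently started a discussion group: one person walks fast, but a group walks far. Follow my WeChat official account to get my personal WeChat, then add me with the note "交流" to join the group. I also maintain a set of Feishu cloud-doc notes covering systematic interview material on backend and big-data systems, free on request by private message.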