StateGraph 详细分析

7 阅读4分钟

StateGraph 详细分析

源码路径:spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/StateGraph.java


1. 类的定位与职责

StateGraph 是整个图引擎的定义层(Definition Layer),负责描述一个有状态工作流的拓扑结构。它不负责执行,只负责声明图的结构——哪些节点、哪些边、如何连接。真正的执行由 compile() 产生的 CompiledGraph 承担。

StateGraph(定义)  ──compile()──▶  CompiledGraph(执行)

2. 核心数据结构

特殊节点常量

常量说明
START__START__图的入口,所有流程必须从此出发
END__END__图的出口,表示流程正常结束
ERROR__ERROR__错误路由节点,流程异常时跳转
NODE_BEFORE__NODE_BEFORE__节点执行前的钩子标识
NODE_AFTER__NODE_AFTER__节点执行后的钩子标识

所有常量以 __ 为前缀,与用户自定义节点隔离(Node.PRIVATE_PREFIX = "__")。

Nodes 容器(内部类)

public static class Nodes {
    public final Set<Node> elements;  // LinkedHashSet,保证去重且保持插入顺序
}
方法说明
anyMatchById(String id)判断是否存在指定 id 的节点
onlySubStateGraphNodes()获取所有未编译子图节点
exceptSubStateGraphNodes()获取排除子图节点后的普通节点列表

Edges 容器(内部类)

public static class Edges {
    public final List<Edge> elements;  // LinkedList,便于顺序查找
}
方法说明
edgeBySourceId(String sourceId)按源节点 id 查找第一条匹配的边
edgesByTargetId(String targetId)反向查找所有指向目标节点的边

3. 节点类型体系

addNode 提供 6 个重载,对应 4 种节点类型:

重载签名实际节点类型说明
addNode(id, AsyncNodeAction)Node普通节点,执行业务逻辑并返回状态更新
addNode(id, AsyncNodeActionWithConfig)Node带编译配置的普通节点
addNode(id, AsyncCommandAction, mappings)Node + 条件边节点本身是路由决策点(单路由)
addNode(id, AsyncMultiCommandAction, mappings)Node + 并行条件边节点决定并行路由(多路由)
addNode(id, CompiledGraph)SubCompiledGraphNode已编译子图(直接内嵌,共享父图状态)
addNode(id, StateGraph)SubStateGraphNode未编译子图(父图编译时一并编译)

后两种子图节点与父图共享同一份 OverAllState,状态透传无需额外转换。


4. 边类型体系

普通边

addEdge(sourceId, targetId)               // 1 对 1
addEdge(List<String> sourceIds, targetId) // 多对 1(汇聚)
addEdge(sourceId, List<String> targetIds) // 1 对多(广播)

同一 sourceId 可累积多个 target,形成并行分叉Edge.isParallel() == true)。

条件边(单路由)

addConditionalEdges(sourceId, AsyncCommandAction, Map<String, String> mappings)
addConditionalEdges(sourceId, AsyncEdgeAction, Map<String, String> mappings)
addConditionalEdges(sourceId, AsyncEdgeActionWithConfig, Map<String, String> mappings)

AsyncCommandAction 在运行时返回一个 Command(包含下一节点 key),通过 mappings 映射到实际节点 ID。

并行条件边(多路由)

addParallelConditionalEdges(sourceId, AsyncMultiCommandAction, Map<String, String> mappings)

AsyncMultiCommandAction 返回 MultiCommand(包含多个节点 key),可同时路由到多个节点并行执行


5. 序列化机制

public static final StateSerializer DEFAULT_JACKSON_SERIALIZER =
    new SpringAIJacksonStateSerializer(OverAllState::new, new ObjectMapper());
序列化器说明
SpringAIJacksonStateSerializer默认,基于 Jackson JSON
SpringAIStateSerializer已废弃
PlainTextStateSerializer已废弃
  • StateSerializer:决定 OverAllState 如何在检查点(checkpoint)中持久化
  • KeyStrategyFactory:控制状态中每个 key 的合并策略(覆盖、追加、自定义聚合)

6. 校验机制(validateGraph()

编译前自动执行,按以下规则检查:

  1. 所有节点自校验(id 不能为空、不能以 __ 开头)
  2. 必须存在从 START 出发的边,否则抛 missingEntryPoint
  3. 所有边的 source/target 节点必须存在于节点集合
  4. 并行边不能有重复 target
  5. 条件边的 mappings 中引用的节点必须存在(END 除外)

7. 编译流程

// 推荐:带自定义配置(生产环境,注入持久化 Saver)
CompiledGraph compile(CompileConfig config)

// 默认:使用 MemorySaver(开发调试)
CompiledGraph compile()

compile() 默认使用 MemorySaver 作为检查点存储,适合开发调试。生产场景应通过 CompileConfig 注入 JDBC / Redis 等持久化 Saver。


8. 图可视化

getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges)
getGraph(GraphRepresentation.Type type, String title)
getGraph(GraphRepresentation.Type type)

将图结构序列化为可视化代码(如 Mermaid 流程图),供 Studio 调试界面渲染。


9. 设计模式

模式体现
Builder / Fluent API所有 addNode / addEdge 返回 this,支持链式调用
工厂方法KeyStrategyFactoryNode.ActionFactory 延迟到编译时创建实例
定义与执行分离StateGraph(定义)→ CompiledGraph(执行)
策略模式StateSerializerKeyStrategy 可替换
模板方法SubStateGraphNode 延迟到父图编译时才编译子图

10. 典型使用示例

var graph = new StateGraph(keyStrategyFactory)
    // 普通节点
    .addNode("llm", state ->
        completedFuture(Map.of("result", callLLM(state))))
    // 工具节点
    .addNode("tool", state ->
        completedFuture(Map.of("result", callTool(state))))
    // 入口 → llm
    .addEdge(START, "llm")
    // llm 根据结果决定路由
    .addConditionalEdges("llm",
        (state, cfg) -> completedFuture(
            needTool(state) ? new Command("tool") : new Command("end")),
        Map.of("tool", "tool", "end", END))
    // 工具执行完回到 llm
    .addEdge("tool", "llm");

// 编译(开发调试)
CompiledGraph compiled = graph.compile();

// 编译(生产,使用 JDBC 持久化检查点)
CompiledGraph compiled = graph.compile(
    CompileConfig.builder()
        .saverConfig(SaverConfig.builder().register(jdbcSaver).build())
        .build());

11. 构造器一览

构造器说明
StateGraph()无名称,空 key 策略,默认 Jackson 序列化
StateGraph(KeyStrategyFactory)自定义 key 策略,默认 Jackson 序列化
StateGraph(String, KeyStrategyFactory)带名称,默认 Jackson 序列化
StateGraph(KeyStrategyFactory, StateSerializer)自定义 key 策略和序列化器
StateGraph(String, KeyStrategyFactory, StateSerializer)完整参数(推荐)