Spring AI Alibaba Graph: Node Streaming Pass-through Example


Original article: Spring AI Alibaba Graph: Node Streaming Pass-through Example

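About this tutorial

Note: this tutorial is based on the official GA release of May 20, 2025, and covers:

  1. Quick-start tutorials for the core modules
  2. Source-level walkthroughs of the core modules
  3. Quick-start tutorials plus source-level walkthroughs of the Spring AI Alibaba enhancements

Versions: JDK 21 + Spring Boot 3.4.5 + Spring AI 1.0.0 + Spring AI Alibaba 1.0.0.2

The chapters below will be published progressively. This installment is the Graph node streaming pass-through example from Chapter 10 (Building Agents with Graph).

The code is open source at: github.com/GTyingzi/sp…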

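Earlier installments published on WeChat:

Chapter 1

SpringAI (GA) chat: quick start + source walkthrough of auto-configuration

SpringAI (GA): walkthrough of the ChatClient call chain

Chapter 2

SpringAI Advisor: quick start + source walkthrough

SpringAI (GA): quick start for message storage with SQLite, MySQL, and Redis

Chapter 3

SpringAI (GA): Tool integration, quick start

SpringAI (GA): Tool source code + walkthrough of the tool-invocation chain

Chapter 4

SpringAI (GA): structured output, quick start + source walkthrough

Chapter 5

SpringAI (GA): vector stores on in-memory, Redis, and ES, quick start

SpringAI (GA): vector database theory and source walkthrough + Redis and ES integration source

Chapter 6

SpringAI (GA): RAG quick start + modular walkthrough

SpringAI (GA): ETL in RAG, quick start

SpringAI (GA): ETL in RAG, source walkthrough

Chapter 7

SpringAI (GA): distributed MCP on Nacos 2

SpringAI (GA): distributed MCP on Nacos 3

SpringAI (GA): MCP source walkthrough

SpringAI (GA): walkthrough of the MCP source in SpringAI

Chapter 8

SpringAI (GA): multi-model evaluation

Chapter 9

SpringAI (GA): observability, quick start + source walkthrough

Chapter 10

Spring AI Alibaba Graph: getting started

Spring AI Alibaba Graph: parallel nodes, quick start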

Graph Node Streaming Pass-through Example

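[!TIP] A Graph wires multiple nodes together to orchestrate a workflow. When one of those nodes calls an AI model, that node must stream the model's response through to the frontend as it is generated, instead of waiting for the complete answer.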
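To make "pass-through" concrete, here is a minimal, JDK-only sketch of what it means for a single node. Each chunk from the model is handed to a downstream consumer the moment it arrives; the consumer stands in for the frontend. The node also accumulates the full text for its final state. The class and method names are hypothetical and are not part of Spring AI Alibaba. A fixed list of strings stands in for the model's streaming response.

```java
import java.util.List;
import java.util.function.Consumer;

// Hypothetical model of node-level streaming pass-through (not the Spring AI Alibaba API)
public class StreamingPassThroughSketch {

    // Forward each chunk immediately, and return the concatenated text as the node's final result
    static String streamAndCollect(List<String> chunks, Consumer<String> downstream) {
        StringBuilder full = new StringBuilder();
        for (String chunk : chunks) {
            downstream.accept(chunk); // pass through right away (e.g. to the frontend)
            full.append(chunk);       // accumulate for the final node state
        }
        return full.toString();
    }

    public static void main(String[] args) {
        List<String> modelChunks = List.of("What is ", "Spring AI", "?\nHow does ", "Graph work?");
        String result = streamAndCollect(modelChunks, c -> System.out.println("chunk: " + c));
        System.out.println("final: " + result);
    }
}
```

In the real node, the chunk source is a `Flux<ChatResponse>` and the "downstream" is the graph's output stream. The split between forwarding and accumulating is the same.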

Full code: the graph directory under github.com/GTyingzi/sp…; this chapter's code is its stream-node module.

First, the end result:

pom.xml

This example uses 1.0.0.3-SNAPSHOT, which changes how a StateGraph is defined compared with 1.0.0.2.

<properties>
    <spring-ai-alibaba.version>1.0.0.3-SNAPSHOT</spring-ai-alibaba.version>
</properties>

<dependencies>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-openai</artifactId>
    </dependency>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
    </dependency>

    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
        <version>${spring-ai-alibaba.version}</version>
    </dependency>

    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>

application.yml

server:
  port: 8080
spring:
  application:
    name: simple
  ai:
    openai:
      api-key: ${AI_DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
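Why does `base-url` stop at `compatible-mode`? Spring AI's OpenAI client appends its chat completions path to the base URL. That path defaults to `/v1/chat/completions` and is configurable through `spring.ai.openai.chat.completions-path`. The sketch below, built around a hypothetical `EndpointSketch` helper, only shows the resulting string concatenation under that default. It does not call any Spring API.

```java
import java.net.URI;

public class EndpointSketch {

    // Assumed default of spring.ai.openai.chat.completions-path
    static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";

    // Join base-url and completions path, tolerating a trailing slash on the base
    static URI chatEndpoint(String baseUrl, String completionsPath) {
        String base = baseUrl.endsWith("/") ? baseUrl.substring(0, baseUrl.length() - 1) : baseUrl;
        return URI.create(base + completionsPath);
    }

    public static void main(String[] args) {
        URI endpoint = chatEndpoint("https://dashscope.aliyuncs.com/compatible-mode", DEFAULT_COMPLETIONS_PATH);
        System.out.println(endpoint);
    }
}
```

With the YAML above, this yields `https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions`, which is DashScope's OpenAI-compatible chat endpoint.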

config

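Fields stored in OverAllState:

  • query: the user's question
  • expandernumber: how many query variants to generate
  • expandercontent: the generated query variants

Register ExpanderNode and wire the edges as START -> expander -> END: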

package com.spring.ai.tutorial.graph.stream.config;

import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.spring.ai.tutorial.graph.stream.node.ExpanderNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.HashMap;

import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;

@Configuration
public class GraphNodeStreamConfiguration {

    private static final Logger logger = LoggerFactory.getLogger(GraphNodeStreamConfiguration.class);

    @Bean
    public StateGraph streamGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
        // Every state key uses ReplaceStrategy: a new write overwrites the previous value
        KeyStrategyFactory keyStrategyFactory = () -> {
            HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();

            // User input
            keyStrategyHashMap.put("query", new ReplaceStrategy());
            keyStrategyHashMap.put("expandernumber", new ReplaceStrategy());
            // Output of ExpanderNode
            keyStrategyHashMap.put("expandercontent", new ReplaceStrategy());
            return keyStrategyHashMap;
        };

        // START -> expander -> END; node_async wraps the synchronous NodeAction
        StateGraph stateGraph = new StateGraph(keyStrategyFactory)
                .addNode("expander", node_async(new ExpanderNode(chatClientBuilder)))
                .addEdge(StateGraph.START, "expander")
                .addEdge("expander", StateGraph.END);

        // Print the graph as PlantUML
        GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML,
                "expander flow");
        logger.info("\n=== expander UML Flow ===");
        logger.info(representation.content());
        logger.info("==================================\n");

        return stateGraph;
    }

}
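Every key above uses `ReplaceStrategy`, so each write from a node simply overwrites the previous value for that key. The JDK-only sketch below models that merge step with a hypothetical `merge` method and one `BinaryOperator` per key. It illustrates the idea only and is not the library's `KeyStrategy` interface. Falling back to replace for unregistered keys is an assumption of the sketch.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

public class KeyStrategySketch {

    // "Replace": discard the old value and keep the new one
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    // Merge one node's output into the current state, key by key
    static Map<String, Object> merge(Map<String, Object> state,
                                     Map<String, Object> nodeOutput,
                                     Map<String, BinaryOperator<Object>> strategies) {
        Map<String, Object> next = new HashMap<>(state);
        nodeOutput.forEach((key, value) -> {
            // Sketch assumption: keys without a registered strategy fall back to REPLACE
            BinaryOperator<Object> strategy = strategies.getOrDefault(key, REPLACE);
            next.put(key, strategy.apply(state.get(key), value));
        });
        return next;
    }

    public static void main(String[] args) {
        Map<String, BinaryOperator<Object>> strategies = Map.of(
                "query", REPLACE, "expandernumber", REPLACE, "expandercontent", REPLACE);
        Map<String, Object> state = Map.of("query", "What is Spring AI?", "expandernumber", 3);
        Map<String, Object> afterExpander = merge(state,
                Map.of("expandercontent", List.of("variant 1", "variant 2", "variant 3")), strategies);
        System.out.println(afterExpander);
    }
}
```

An "append" strategy would differ only in the operator, for example concatenating the old and new lists instead of returning `newValue`.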

node

ExpanderNode
package com.spring.ai.tutorial.graph.stream.node;

import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;
import org.bsc.async.AsyncGenerator;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import reactor.core.publisher.Flux;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class ExpanderNode implements NodeAction {

    // Query-expansion prompt; {number} and {query} are filled in at call time
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("You are an expert at information retrieval and search optimization.\nYour task is to generate {number} different versions of the given query.\n\nEach variant must cover different perspectives or aspects of the topic,\nwhile maintaining the core intent of the original query. The goal is to\nexpand the search space and improve the chances of finding relevant information.\n\nDo not explain your choices or add any other text.\nProvide the query variants separated by newlines.\n\nOriginal query: {query}\n\nQuery variants:\n");

    private final ChatClient chatClient;

    // Number of variants used when "expandernumber" is absent from the state
    private static final Integer DEFAULT_NUMBER = 3;

    public ExpanderNode(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String query = state.value("query", "");
        Integer expanderNumber = state.value("expandernumber", DEFAULT_NUMBER);

        // Request a streaming response instead of a blocking call
        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user(user -> user.text(DEFAULT_PROMPT_TEMPLATE.getTemplate())
                        .param("number", expanderNumber)
                        .param("query", query))
                .stream()
                .chatResponse();

        // Wrap the Flux in a generator: chunks are emitted as streaming outputs tagged "expanderllmstream",
        // and mapResult maps the completed response to the node's final state update
        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("expanderllmstream")
                .startingState(state)
                .mapResult(response -> {
                    String text = response.getResult().getOutput().getText();
                    List<String> queryVariants = Arrays.asList(text.split("\n"));
                    return Map.of("expandercontent", queryVariants);
                }).build(chatResponseFlux);
        // Returning the generator (not a finished value) lets the graph stream it through to the caller
        return Map.of("expandercontent", generator);
    }
}
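`mapResult` above splits the model's final text on `\n`, so blank lines or stray whitespace in the reply end up as empty or padded variants. The sketch below shows a more tolerant parse. It also includes a naive `{name}` substitution that mimics how the `{number}` and `{query}` placeholders get filled. Both methods are hypothetical JDK-only helpers. Spring AI renders templates with its own renderer, which this sketch does not reproduce.

```java
import java.util.List;
import java.util.Map;

public class QueryVariantsSketch {

    // Naive {name} substitution, for illustration only
    static String render(String template, Map<String, ?> params) {
        String out = template;
        for (Map.Entry<String, ?> e : params.entrySet()) {
            out = out.replace("{" + e.getKey() + "}", String.valueOf(e.getValue()));
        }
        return out;
    }

    // Split on any line break (\n, \r, \r\n), trim each line, and drop empty lines
    static List<String> parseVariants(String text) {
        return text.lines()
                .map(String::strip)
                .filter(line -> !line.isEmpty())
                .toList();
    }

    public static void main(String[] args) {
        String prompt = render("Generate {number} variants of: {query}",
                Map.of("number", 3, "query", "What is Spring AI Alibaba Graph?"));
        System.out.println(prompt);
        System.out.println(parseVariants("What is Graph?\n\n  How do nodes stream?  \nWhy use StateGraph?\n"));
    }
}
```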

controller

GraphStreamController
  • Sinks.Many<ServerSentEvent<String>> sink: receives the streamed data and exposes it to the HTTP response as a Flux
package com.spring.ai.tutorial.graph.stream.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.spring.ai.tutorial.graph.stream.controller.GraphProcess.GraphProcess;
import org.bsc.async.AsyncGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

import java.util.HashMap;
import java.util.Map;

@RestController
@RequestMapping("/graph/stream")
public class GraphStreamController {

    private static final Logger logger = LoggerFactory.getLogger(GraphStreamController.class);

    private final CompiledGraph compiledGraph;

    public GraphStreamController(@Qualifier("streamGraph") StateGraph stateGraph) throws GraphStateException {
        // Compile the graph once and reuse it for every request
        this.compiledGraph = stateGraph.compile();
    }

    @GetMapping(value = "/expand", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> expand(@RequestParam(value = "query", defaultValue = "Hello, nice to meet you. Could you briefly introduce yourself?", required = false) String query,
                                                @RequestParam(value = "expandernumber", defaultValue = "3", required = false) Integer expanderNumber,
                                                @RequestParam(value = "threadid", defaultValue = "yingzi", required = false) String threadId) {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        // Initial state; the keys match those registered in the KeyStrategyFactory
        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("query", query);
        objectMap.put("expandernumber", expanderNumber);

        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        // Unicast sink: one subscriber (this HTTP response); items are buffered until consumed
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        // stream() yields node outputs, including streamed chunks, as they are produced
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.stream(objectMap, runnableConfig);
        graphProcess.processStream(resultFuture, sink);

        return sink.asFlux()
                .doOnCancel(() -> logger.info("Client disconnected from stream"))
                .doOnError(e -> logger.error("Error occurred during streaming", e));
    }


}
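On the wire, each `ServerSentEvent` typically arrives as `data:` lines followed by a blank line. A client that is not a browser `EventSource` has to reassemble events itself. The JDK-only sketch below does that for the `data` field, following the SSE specification's rules:

- one optional leading space after the colon is stripped;
- multiple `data:` lines are joined with `\n`;
- an event is dispatched on a blank line.

The `SseParseSketch` class is hypothetical, and other fields such as `event:` and `id:` are ignored for brevity.

```java
import java.util.ArrayList;
import java.util.List;

public class SseParseSketch {

    // Parse raw SSE text into the data payload of each dispatched event
    static List<String> parseDataEvents(String raw) {
        List<String> events = new ArrayList<>();
        StringBuilder data = null; // null = no data line seen yet for the current event
        for (String line : raw.split("\r\n|\r|\n", -1)) {
            if (line.isEmpty()) {                       // blank line: dispatch the current event
                if (data != null) {
                    events.add(data.toString());
                    data = null;
                }
            } else if (line.startsWith("data:")) {
                String value = line.substring(5);
                if (value.startsWith(" ")) {
                    value = value.substring(1);         // drop a single leading space
                }
                if (data == null) {
                    data = new StringBuilder(value);
                } else {
                    data.append('\n').append(value);    // join multi-line data with \n
                }
            }
            // Comments (":...") and other fields are ignored; an unterminated final event is discarded
        }
        return events;
    }

    public static void main(String[] args) {
        String raw = "data:{\"node\":\"expander\"}\n\n"
                + "data: first line\ndata: second line\n\n"
                + ": keep-alive comment\n\n";
        parseDataEvents(raw).forEach(System.out::println);
    }
}
```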
GraphProcess
  • ExecutorService executor: a thread pool on which the graph's output stream is consumed

Each output is then written into the sink.

package com.spring.ai.tutorial.graph.stream.controller.GraphProcess;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.bsc.async.AsyncGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Sinks;

import java.util.Map;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class GraphProcess {

    private static final Logger logger = LoggerFactory.getLogger(GraphProcess.class);

    private final ExecutorService executor = Executors.newSingleThreadExecutor();

    private CompiledGraph compiledGraph;

    public GraphProcess(CompiledGraph compiledGraph) {
        this.compiledGraph = compiledGraph;
    }

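    public void processStream(AsyncGenerator<NodeOutput> generator, Sinks.Many<ServerSentEvent<String>> sink) {
        executor.submit(() -> {
            generator.forEachAsync(output -> {
                try {
                    logger.info("output = {}", output);
                    String nodeName = output.node();
                    String content;
                    if (output instanceof StreamingOutput streamingOutput) {
                        // Streamed model chunk, shaped as {"<node>": "<chunk>"}
                        content = JSON.toJSONString(Map.of(nodeName, streamingOutput.chunk()));
                    } else {
                        // Regular node output, with the state data and the node name
                        JSONObject nodeOutput = new JSONObject();
                        nodeOutput.put("data", output.state().data());
                        nodeOutput.put("node", nodeName);
                        content = JSON.toJSONString(nodeOutput);
                    }
                    sink.tryEmitNext(ServerSentEvent.builder(content).build());
                } catch (Exception e) {
                    throw new CompletionException(e);
                }
            }).thenAccept(v -> {
                // Normal completion
                sink.tryEmitComplete();
            }).exceptionally(e -> {
                sink.tryEmitError(e);
                return null;
            });
        });
        // The controller creates one GraphProcess per request, so this executor runs a single task.
        // Shutting it down lets that task finish, then the worker thread exits instead of idling.
        // As a consequence, processStream should be called only once per GraphProcess instance.
        executor.shutdown();
    }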
}
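The same producer/consumer shape can be modelled with the JDK alone. A background task walks the node outputs and turns each one into one of the two payload shapes used above. It pushes each payload into a queue that stands in for Reactor's `Sinks.Many`, and finally pushes an end marker. This is a hypothetical sketch, not a replacement for `GraphProcess`: the `Output` record, the `__DONE__` marker, and the hand-written JSON without escaping are all illustrative. It also shuts down a per-call single-thread executor right after submitting. The submitted task still runs, and the worker thread exits once it finishes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;

public class ProcessStreamSketch {

    // chunk != null: a streamed model chunk; otherwise a completed node with pre-serialized state
    record Output(String node, String chunk, String stateJson) {}

    // End-of-stream marker standing in for sink.tryEmitComplete()
    static final String DONE = "__DONE__";

    // Two payload shapes, built by hand without JSON escaping (demo only)
    static String toPayload(Output o) {
        if (o.chunk() != null) {
            return "{\"" + o.node() + "\":\"" + o.chunk() + "\"}";
        }
        return "{\"data\":" + o.stateJson() + ",\"node\":\"" + o.node() + "\"}";
    }

    // Producer: runs on its own single-thread executor and always signals completion
    static void processStream(List<Output> outputs, BlockingQueue<String> sink) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        executor.submit(() -> {
            try {
                for (Output o : outputs) {
                    sink.add(toPayload(o));
                }
            } finally {
                sink.add(DONE);
            }
        });
        executor.shutdown(); // the submitted task still runs; no new tasks are accepted
    }

    // Consumer: collect payloads until the end marker arrives
    static List<String> runAndCollect(List<Output> outputs) {
        BlockingQueue<String> sink = new LinkedBlockingQueue<>();
        processStream(outputs, sink);
        List<String> received = new ArrayList<>();
        try {
            for (String item = sink.take(); !item.equals(DONE); item = sink.take()) {
                received.add(item);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return received;
    }

    public static void main(String[] args) {
        List<Output> outputs = List.of(
                new Output("expanderllmstream", "What is", null),
                new Output("expanderllmstream", " Graph?", null),
                new Output("expander", null, "{\"query\":\"q\"}"));
        runAndCollect(outputs).forEach(System.out::println);
    }
}
```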
Result

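Study group

Hi, I'm Yingzi (影子). I have worked at 🐻, a new-energy company, and 老铁, and I am now an AI R&D engineer and a Committer in the Spring AI Alibaba open-source community. I've just started a discussion group: one person walks fast, but a group walks far. I also maintain a set of Feishu cloud-doc notes with systematic interview material on backend and big-data systems, available for free by private message.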