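1. Introduction: Why Spring AI
Large language models (LLMs) are reshaping software development with unprecedented depth and breadth. Java developers face a pressing question, especially engineers who work deep in the Spring ecosystem: how do we integrate AI capabilities elegantly, efficiently and reliably into existing microservices and enterprise applications?
"I just want to call OpenAI from my Spring Boot project. Why does it have to be so complicated?"
Many engineers feel exactly this way. Calling the model API directly with RestTemplate or HttpClient certainly works, but you soon run into a series of engineering challenges:
- Model heterogeneity: OpenAI, Azure OpenAI, Google Gemini, Anthropic Claude, Alibaba's Qwen (Tongyi Qianwen) and others each have their own API, authentication scheme and request/response format, so switching models is very expensive.
- Fragmented features: a complete AI application is more than an API call. Prompt engineering, structured output, conversation memory, RAG (retrieval-augmented generation), tool calling and more all need large amounts of "glue code" to hold them together.
- Enterprise requirements: production systems must also cover non-functional requirements such as connection pool management, asynchronous calls, error retries, monitoring, logging and configuration management.
Spring AI was created to address exactly these pain points. It does not reinvent the wheel. Instead it positions itself as the "Spring Boot" of AI applications. Through strong abstractions and auto-configuration it turns a complex integration job into a simple one, so developers can add AI capabilities to an application as easily as they add spring-boot-starter-web.
Choosing Spring AI means choosing more than a library: it means adopting a standard approach to engineering AI features. It lets us:
- Focus on business logic: developers are freed from tedious API adaptation and low-level implementation details.
- Embrace the Spring ecosystem: it integrates seamlessly with Spring Boot, Spring Cloud, Spring Data and more, so the whole ecosystem works in our favor.
- Be production-ready: it provides the portability, extensibility and maintainability that enterprise applications need.
This article starts with environment setup and works through Spring AI's core features step by step. A complete enterprise-grade example then shows how its design fits together and how it performs in practice.
1.1 Core Advantages of Spring AI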
graph TD
A[Spring AI 1.0] --> B[Unified abstraction layer]
A --> C[Production-ready]
A --> D[Spring ecosystem integration]
A --> E[Multi-model support]
B --> B1[ChatModel interface]
B --> B2[EmbeddingModel interface]
B --> B3[ImageModel interface]
C --> C1[Monitoring metrics]
C --> C2[Security and authentication]
C --> C3[Configuration management]
D --> D1[Spring Boot Starter]
D --> D2[Spring Cloud integration]
D --> D3[Spring Security]
E --> E1[OpenAI]
E --> E2[Azure OpenAI]
E --> E3[Anthropic]
E --> E4[Local models]
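Imagine working without the Spring Framework: every project would need hand-written code for database connections, transaction management and dependency injection, which would be painful. Spring AI sets out to solve the equivalent problems for AI integration.
1.2 When to Use Spring AI
Good fits for Spring AI:
- Enterprise Java applications that need AI capabilities
- Applications that must support several AI models
- Projects with high stability requirements in production
- Teams that want fast prototyping and iteration
Less suitable scenarios:
- Pure AI research projects (the Python ecosystem is much richer)
- Real-time applications with extreme performance requirements
- Simple, one-off AI calls
2. Getting Started: Environment Setup and Basic Configuration
2.1 Maven Dependency Configuration
Let's start with the most basic step, the dependencies. Spring AI is modular, so you add only the model and vector-store integrations you actually need.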
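<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
                             http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.example</groupId>
    <artifactId>spring-ai-demo</artifactId>
    <version>1.0.0</version>
    <packaging>jar</packaging>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <!-- Spring AI 1.0.0 targets Spring Boot 3.4.x / 3.5.x -->
        <version>3.4.5</version>
        <relativePath/>
    </parent>
    <properties>
        <java.version>17</java.version>
        <spring-ai.version>1.0.0</spring-ai.version>
        <resilience4j.version>2.2.0</resilience4j.version>
    </properties>
    <!-- The Spring AI BOM manages the versions of all Spring AI modules -->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <!-- Spring Boot basics -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <!-- Required by the spring.data.redis.lettuce.pool settings -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-pool2</artifactId>
        </dependency>
        <!-- @Cacheable support -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-cache</artifactId>
        </dependency>
        <!-- OpenAI support (the starter pulls in the Spring AI core modules, so no spring-ai-core is needed) -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-model-openai</artifactId>
        </dependency>
        <!-- Vector database support (Chroma) -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-vector-store-chroma</artifactId>
        </dependency>
        <!-- Document processing (PDF) -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pdf-document-reader</artifactId>
        </dependency>
        <!-- Retry and resilience: @Retryable and the Resilience4j annotations need AOP -->
        <dependency>
            <groupId>org.springframework.retry</groupId>
            <artifactId>spring-retry</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        <dependency>
            <groupId>io.github.resilience4j</groupId>
            <artifactId>resilience4j-spring-boot3</artifactId>
            <version>${resilience4j.version}</version>
        </dependency>
        <!-- JSON Schema validation of model output -->
        <dependency>
            <groupId>com.networknt</groupId>
            <artifactId>json-schema-validator</artifactId>
            <version>1.4.0</version>
        </dependency>
        <!-- Monitoring and metrics -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-actuator</artifactId>
        </dependency>
        <dependency>
            <groupId>io.micrometer</groupId>
            <artifactId>micrometer-registry-prometheus</artifactId>
        </dependency>
        <!-- Lombok (@Slf4j) -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <!-- Test dependencies -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <!-- Spring AI 1.0.0 GA is published to Maven Central, so no milestone repository is needed -->
</project>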
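2.2 The application.yml Configuration File
The configuration file is the "brain" of a Spring AI application. Well-chosen settings keep the application stable across environments.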
spring:
  application:
    name: spring-ai-demo
  # AI model configuration
  ai:
    openai:
      api-key: ${OPENAI_API_KEY:your-api-key}
      base-url: ${OPENAI_BASE_URL:https://api.openai.com}
      chat:
        options:
          model: gpt-4
          temperature: 0.7
          max-tokens: 2048
          top-p: 1.0
          frequency-penalty: 0.0
          presence-penalty: 0.0
      embedding:
        options:
          # text-embedding-ada-002 does not accept a dimensions parameter; the text-embedding-3 models do
          model: text-embedding-3-small
          dimensions: 1536
    # Vector database configuration
    vectorstore:
      chroma:
        client:
          host: ${CHROMA_HOST:http://localhost}
          port: ${CHROMA_PORT:8000}
        collection-name: ${CHROMA_COLLECTION:spring-ai-docs}
        initialize-schema: true
  # HTTP client timeouts (Spring Boot 3.4+ properties, applied to the auto-configured RestClient.Builder).
  # Spring AI has no connection-pool properties of its own; pool sizes need a custom ClientHttpRequestFactory bean.
  http:
    client:
      connect-timeout: 30s
      read-timeout: 60s
  # Redis configuration (used for session storage)
  data:
    redis:
      host: ${REDIS_HOST:localhost}
      port: ${REDIS_PORT:6379}
      password: ${REDIS_PASSWORD:}
      database: 0
      timeout: 2000ms
      lettuce:
        pool:
          max-active: 20
          max-idle: 10
          min-idle: 5
          max-wait: 2000ms
# Application settings (custom properties, read by the application's own code)
app:
  ai:
    # Conversation settings
    conversation:
      max-history: 10
      ttl: 3600              # 1 hour (seconds)
      cleanup-interval: 300  # clean up every 5 minutes (seconds)
    # RAG settings
    rag:
      chunk-size: 1000
      chunk-overlap: 200
      max-results: 5
      similarity-threshold: 0.7
    # Retry settings
    retry:
      max-attempts: 3
      backoff-delay: 1000
      max-delay: 10000
    # Rate-limit settings
    rate-limit:
      requests-per-minute: 60
      burst-capacity: 10
# Monitoring configuration
management:
  endpoints:
    web:
      exposure:
        include: health,info,metrics,prometheus
  endpoint:
    health:
      show-details: always
  prometheus:
    metrics:
      export:
        enabled: true
  metrics:
    tags:
      application: ${spring.application.name}
      environment: ${spring.profiles.active:default}
# Logging configuration
logging:
  level:
    org.springframework.ai: DEBUG
    com.example: DEBUG
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level [%X{traceId},%X{spanId}] %logger{36} - %msg%n"
    file: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level [%X{traceId},%X{spanId}] %logger{36} - %msg%n"
  file:
    name: logs/spring-ai-demo.log
  logback:
    rollingpolicy:
      max-file-size: 100MB
      max-history: 30
---
# Development profile
spring:
  config:
    activate:
      on-profile: dev
  ai:
    openai:
      chat:
        options:
          temperature: 0.9 # more creative output is acceptable in development
logging:
  level:
    root: INFO
    org.springframework.ai: DEBUG
---
# Production profile
spring:
  config:
    activate:
      on-profile: prod
  ai:
    openai:
      chat:
        options:
          temperature: 0.3 # more conservative in production
          max-tokens: 1024 # caps cost
logging:
  level:
    root: WARN
    com.example: INFO
app:
  ai:
    rate-limit:
      requests-per-minute: 30 # stricter rate limit in production
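The app.ai.rate-limit entries are custom properties, so they only take effect if application code reads them. A token bucket is one common way to interpret requests-per-minute and burst-capacity. The bucket holds at most burst-capacity tokens and refills at requests-per-minute / 60 tokens per second. The sketch below is a minimal, hypothetical TokenBucketRateLimiter in plain Java. The class name and the injectable nanosecond clock (used for testability) are assumptions, not Spring AI APIs.

```java
import java.util.concurrent.TimeUnit;
import java.util.function.LongSupplier;

/**
 * Minimal token-bucket limiter (illustrative, not part of Spring AI).
 * capacity    ~ app.ai.rate-limit.burst-capacity
 * refill rate ~ app.ai.rate-limit.requests-per-minute spread evenly over a minute
 */
final class TokenBucketRateLimiter {
    private final long capacity;
    private final double tokensPerNano;
    private final LongSupplier nanoClock;
    private double tokens;
    private long lastRefill;

    TokenBucketRateLimiter(int requestsPerMinute, int burstCapacity, LongSupplier nanoClock) {
        this.capacity = burstCapacity;
        this.tokensPerNano = requestsPerMinute / (double) TimeUnit.MINUTES.toNanos(1);
        this.nanoClock = nanoClock;
        this.tokens = burstCapacity;            // start full so an initial burst is allowed
        this.lastRefill = nanoClock.getAsLong();
    }

    /** Consumes one token and returns true if a request may proceed right now. */
    synchronized boolean tryAcquire() {
        long now = nanoClock.getAsLong();
        // Refill proportionally to the elapsed time, never beyond the burst capacity
        tokens = Math.min(capacity, tokens + (now - lastRefill) * tokensPerNano);
        lastRefill = now;
        if (tokens >= 1.0) {
            tokens -= 1.0;
            return true;
        }
        return false;
    }
}
```

In production you would construct it with System::nanoTime. You could also map the same two numbers onto Resilience4j's RateLimiter instead of maintaining your own class.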
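3. Core Capabilities in Depth
3.1 Wrapping Model Calls
3.1.1 Designing a Unified AI Client Wrapper
Enterprise applications usually need to support several AI models behind one calling interface. Spring AI handles this with its abstraction layer: application code depends on the ChatModel interface, and the starter and configuration on the classpath decide which provider implements it.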
package com.example.ai.service;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.context.annotation.Lazy;
import org.springframework.retry.annotation.Backoff;
import org.springframework.retry.annotation.Recover;
import org.springframework.retry.annotation.Retryable;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
/**
 * Unified wrapper for AI model calls
 * Provides multi-model support, caching, retries and other enterprise features.
 * Requires @EnableRetry (see RetryConfig) and @EnableCaching on a configuration class.
 */
@Service
public class AIModelService {
    // Two ChatModel beans must be told apart by qualifier; these bean names are illustrative
    @Autowired
    @Qualifier("primaryChatModel")
    private ChatModel primaryChatModel;
    @Autowired
    @Qualifier("fallbackChatModel")
    private ChatModel fallbackChatModel;
    @Autowired
    private AIMetricsService metricsService;
    @Autowired
    @Qualifier("aiTaskExecutor")
    private Executor aiTaskExecutor;
    // Proxy reference to this bean: self.chat() goes through the Spring proxy, so @Retryable and
    // @Cacheable also apply when called from chatAsync() (a plain this.chat() would bypass them)
    @Lazy
    @Autowired
    private AIModelService self;
    /**
     * Synchronous chat call.
     * Failures propagate so that Spring Retry can retry the primary model; once the
     * attempts are exhausted, recoverWithFallback() switches to the backup model.
     * @param prompt  user input
     * @param context context parameters
     * @return AI response
     */
    @Retryable(retryFor = Exception.class, maxAttempts = 3,
               backoff = @Backoff(delay = 1000, multiplier = 2))
    @Cacheable(value = "ai-responses", key = "{#prompt, #context}")
    public String chat(String prompt, Map<String, Object> context) {
        long startTime = System.currentTimeMillis();
        ChatResponse response = primaryChatModel.call(buildPrompt(prompt, context));
        metricsService.recordSuccess("primary", System.currentTimeMillis() - startTime);
        return response.getResult().getOutput().getText();
    }
    /**
     * Asynchronous chat call, executed on the dedicated aiTaskExecutor pool
     * @param prompt  user input
     * @param context context parameters
     * @return future holding the AI response
     */
    public CompletableFuture<String> chatAsync(String prompt, Map<String, Object> context) {
        return CompletableFuture.supplyAsync(() -> self.chat(prompt, context), aiTaskExecutor);
    }
    /**
     * Streaming chat call
     * @param prompt   user input
     * @param context  context parameters
     * @param callback streaming callback
     */
    public void chatStream(String prompt, Map<String, Object> context, StreamCallback callback) {
        primaryChatModel.stream(buildPrompt(prompt, context))
            .subscribe(
                chatResponse -> {
                    // Some chunks (for example the last one) may carry no text
                    if (chatResponse.getResult() != null
                            && chatResponse.getResult().getOutput().getText() != null) {
                        callback.onNext(chatResponse.getResult().getOutput().getText());
                    }
                },
                error -> {
                    metricsService.recordError("primary", error);
                    callback.onError(error);
                },
                callback::onComplete
            );
    }
    /**
     * Fallback model call, invoked by Spring Retry after the last failed attempt
     */
    @Recover
    public String recoverWithFallback(Exception originalError, String prompt, Map<String, Object> context) {
        long startTime = System.currentTimeMillis();
        metricsService.recordFallback("primary", originalError);
        try {
            ChatResponse response = fallbackChatModel.call(buildPrompt(prompt, context));
            metricsService.recordSuccess("fallback", System.currentTimeMillis() - startTime);
            return response.getResult().getOutput().getText();
        } catch (Exception fallbackError) {
            metricsService.recordError("fallback", fallbackError);
            throw new AIServiceException("Both primary and fallback models failed",
                originalError, fallbackError);
        }
    }
    /**
     * Builds the prompt
     */
    private Prompt buildPrompt(String userInput, Map<String, Object> context) {
        if (context != null && context.containsKey("template")) {
            // Render the template from the context; the raw input is exposed as {userInput}
            // (an illustrative variable name) so that it is not silently dropped
            Map<String, Object> variables = new HashMap<>(context);
            variables.put("userInput", userInput);
            return new PromptTemplate((String) context.get("template")).create(variables);
        }
        return new Prompt(List.of(new UserMessage(userInput)));
    }
    /**
     * Streaming callback interface
     */
    public interface StreamCallback {
        void onNext(String content);
        void onError(Throwable error);
        void onComplete();
    }
}
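Spring Retry handles the retry-then-fall-back flow declaratively. The plain-Java sketch below spells out the same control flow so the order of events is explicit. It retries the primary call with exponential backoff, then tries the fallback once and keeps the primary failure attached for diagnostics. RetryThenFallback is a hypothetical helper for illustration. In a Spring application, annotations such as Spring Retry's @Retryable (optionally paired with a @Recover method) play these roles.

```java
import java.util.function.Supplier;

/** Hypothetical helper: retry the primary call with exponential backoff, then fall back once. */
final class RetryThenFallback {

    static <T> T call(Supplier<T> primary, Supplier<T> fallback,
                      int maxAttempts, long initialDelayMs, double multiplier) {
        long delay = initialDelayMs;
        RuntimeException lastError = null;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            try {
                return primary.get();
            } catch (RuntimeException e) {
                lastError = e;
                if (attempt == maxAttempts) {
                    break;                                  // retries exhausted
                }
                try {
                    Thread.sleep(delay);                    // back off before the next attempt
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();     // preserve the interrupt and stop retrying
                    break;
                }
                delay = (long) (delay * multiplier);
            }
        }
        try {
            return fallback.get();
        } catch (RuntimeException fallbackError) {
            if (lastError != null && lastError != fallbackError) {
                fallbackError.addSuppressed(lastError);     // keep the primary failure for diagnostics
            }
            throw fallbackError;
        }
    }
}
```

For example, `RetryThenFallback.call(() -> primary.chat(p), () -> backup.chat(p), 3, 1000, 2.0)` waits 1 s and then 2 s between the three primary attempts before switching to the backup.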
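3.1.2 Implementing the Adapter Pattern for Multiple Models
The adapter pattern lets us hide the interface differences between AI providers. Spring AI applies the same idea internally: each provider module implements the common ChatModel interface, as the diagram shows.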
graph TD
subgraph Your Application
A[MyService] --> B{ChatClient};
end
subgraph Spring AI Core
B -- uses --> C[ChatModel];
end
subgraph Spring AI Adapters
C -- can be implemented by --> D[OpenAiChatModel];
C -- can be implemented by --> E[VertexAiGeminiChatModel];
C -- can be implemented by --> F[OllamaChatModel];
end
A --> |"Tell me a joke"| B;
B --> |Generates a Prompt object| C;
C --> |Calls external API| D;
D --> |Sends JSON to OpenAI API| G[OpenAI API];
G --> |Returns JSON response| D;
D --> |Parses to ChatResponse| C;
C --> |Returns ChatResponse| B;
B --> |Extracts content| A;
package com.example.ai.adapter;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
/**
 * Model adapter interface
 * Unifies how different AI providers are called.
 */
public interface ModelAdapter {
    /** Returns the model name. */
    String getModelName();
    /** Returns the provider name. */
    String getProvider();
    /** Calls the model. */
    ChatResponse call(Prompt prompt);
    /** Checks whether the model is available. */
    boolean isAvailable();
    /** Returns the model configuration (ModelConfig is an application class not shown here). */
    ModelConfig getConfig();
}
package com.example.ai.adapter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;
@Slf4j
@Component
@Qualifier("openaiAdapter")
public class OpenAIAdapter implements ModelAdapter {
    // OpenAiChatModel is auto-configured by spring-ai-starter-model-openai
    @Autowired
    private OpenAiChatModel chatModel;
    @Autowired
    private ModelConfig config;
    @Override
    public String getModelName() {
        return config.getModelName();
    }
    @Override
    public String getProvider() {
        return config.getProvider();
    }
    @Override
    public ChatResponse call(Prompt prompt) {
        try {
            log.debug("Calling OpenAI model with prompt: {}", prompt);
            return chatModel.call(prompt);
        } catch (Exception e) {
            log.error("OpenAI model call failed", e);
            throw new ModelCallException("OpenAI call failed", e);
        }
    }
    @Override
    public boolean isAvailable() {
        try {
            // Simple health check. It sends a real (billed) request, so cache the
            // result or run it on a schedule instead of calling it per request
            ChatResponse response = chatModel.call(new Prompt("Hello"));
            return response != null && response.getResult() != null;
        } catch (Exception e) {
            log.warn("OpenAI health check failed", e);
            return false;
        }
    }
    @Override
    public ModelConfig getConfig() {
        return config;
    }
}
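With several ModelAdapter beans in place, something has to decide which one serves a request. A simple policy is priority order with failover: skip adapters that report themselves unavailable, and move on when a call throws. The sketch uses a simplified TextModelAdapter interface (String in, String out) so it runs without Spring. Both TextModelAdapter and ModelRouter are hypothetical names. In the real service you would inject a List<ModelAdapter> instead.

```java
import java.util.ArrayList;
import java.util.List;

/** Simplified stand-in for ModelAdapter, using String in place of Prompt/ChatResponse. */
interface TextModelAdapter {
    String getProvider();
    boolean isAvailable();
    String call(String prompt);
}

/** Hypothetical router: tries adapters in priority order and returns the first successful answer. */
final class ModelRouter {
    private final List<TextModelAdapter> adapters;

    ModelRouter(List<TextModelAdapter> adaptersInPriorityOrder) {
        this.adapters = List.copyOf(adaptersInPriorityOrder);
    }

    String route(String prompt) {
        List<String> failures = new ArrayList<>();
        for (TextModelAdapter adapter : adapters) {
            if (!adapter.isAvailable()) {
                failures.add(adapter.getProvider() + ": unavailable");
                continue;
            }
            try {
                return adapter.call(prompt);
            } catch (RuntimeException e) {
                // Remember why this provider failed, then try the next one
                failures.add(adapter.getProvider() + ": " + e.getMessage());
            }
        }
        throw new IllegalStateException("No model adapter succeeded: " + failures);
    }
}
```

The exception message lists every provider with its reason, which makes a total outage much easier to diagnose than a single "call failed".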
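3.1.3 Asynchronous Calls and Connection Pool Management
Under high concurrency, asynchronous calls and resource pooling are critical. The configuration below gives AI tasks their own thread pool, so slow model calls cannot exhaust the threads that serve web requests.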
package com.example.ai.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
@Configuration
@EnableAsync
public class AsyncConfig {
    @Bean(name = "aiTaskExecutor")
    public Executor aiTaskExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        // Core pool size
        executor.setCorePoolSize(10);
        // Maximum pool size (threads beyond the core size are only created once the queue is full)
        executor.setMaxPoolSize(50);
        // Queue capacity
        executor.setQueueCapacity(200);
        // Thread name prefix
        executor.setThreadNamePrefix("AI-Task-");
        // Rejection policy: run the task in the calling thread
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        // Let queued tasks finish on shutdown
        executor.setWaitForTasksToCompleteOnShutdown(true);
        // Maximum time to wait on shutdown, in seconds
        executor.setAwaitTerminationSeconds(60);
        executor.initialize();
        return executor;
    }
}
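CallerRunsPolicy keeps aiTaskExecutor from dropping work. When all 50 threads are busy and the 200-slot queue is full, the submitting thread runs the task itself, which slows producers down naturally. The standalone sketch below makes this easy to observe. It shrinks a plain ThreadPoolExecutor to one thread and a one-slot queue, using the same ThreadPoolExecutor.CallerRunsPolicy class configured above. CallerRunsDemo is a hypothetical name.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/** Shows CallerRunsPolicy at work with a deliberately tiny pool (1 thread, queue of 1). */
final class CallerRunsDemo {

    /** Returns the name of the thread that executed the task the pool could not accept. */
    static String overflowThreadName() {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
            1, 1, 0L, TimeUnit.MILLISECONDS,
            new ArrayBlockingQueue<>(1),
            new ThreadPoolExecutor.CallerRunsPolicy());
        CountDownLatch release = new CountDownLatch(1);
        Runnable blocker = () -> {
            try {
                release.await();                     // hold the worker until we are done
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        };
        String[] ranOn = new String[1];
        try {
            pool.execute(blocker);                   // occupies the only worker thread
            pool.execute(blocker);                   // fills the only queue slot
            // Pool and queue are full: CallerRunsPolicy runs this task right here, in the caller
            pool.execute(() -> ranOn[0] = Thread.currentThread().getName());
        } finally {
            release.countDown();
            pool.shutdown();
        }
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return ranOn[0];
    }
}
```

The returned name is the caller's own thread name. In a web application that caller would be a request thread, which is why CallerRunsPolicy acts as back-pressure rather than as extra capacity.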
package com.example.ai.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
@Slf4j
@Service
public class AsyncAIService {
    @Autowired
    private AIModelService aiModelService;
    /**
     * Asynchronous batch processing.
     * Deliberately not annotated with @Async: chatAsync() already runs every call asynchronously,
     * and an @Async wrapper would tie up an extra pool thread that only waits for the others.
     * @param prompts list of prompts
     * @param context context
     * @return future holding the results, in prompt order
     */
    public CompletableFuture<List<String>> batchProcess(List<String> prompts, Map<String, Object> context) {
        log.info("Starting batch processing for {} prompts", prompts.size());
        List<CompletableFuture<String>> futures = prompts.stream()
            .map(prompt -> aiModelService.chatAsync(prompt, context))
            .collect(Collectors.toList());
        // Wait for all tasks to complete
        CompletableFuture<Void> allOf = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
        return allOf.thenApply(v -> futures.stream()
                .map(CompletableFuture::join)
                .collect(Collectors.toList()))
            .whenComplete((result, throwable) -> {
                if (throwable != null) {
                    log.error("Batch processing failed", throwable);
                } else {
                    log.info("Batch processing completed successfully, {} results", result.size());
                }
            });
    }
    /**
     * Asynchronous processing of a single prompt (no @Async needed, for the same reason)
     * @param prompt  prompt
     * @param context context
     * @return future holding the result
     */
    public CompletableFuture<String> process(String prompt, Map<String, Object> context) {
        log.info("Starting async processing for prompt: {}", prompt);
        return aiModelService.chatAsync(prompt, context)
            .whenComplete((result, throwable) -> {
                if (throwable != null) {
                    log.error("Async processing failed", throwable);
                } else {
                    log.info("Async processing completed successfully");
                }
            });
    }
}
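In batchProcess, a single failed prompt makes allOf complete exceptionally, so the results of every other prompt are lost as well. When partial results are acceptable, each future can recover on its own before the join. A minimal sketch, assuming the model call is passed in as a Function<String, String>. TolerantBatch and the placeholder string are hypothetical.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Function;

/** Hypothetical batch helper: a failed prompt yields a placeholder instead of failing the whole batch. */
final class TolerantBatch {

    static CompletableFuture<List<String>> process(List<String> prompts,
                                                   Function<String, String> model,
                                                   Executor executor,
                                                   String placeholder) {
        List<CompletableFuture<String>> futures = prompts.stream()
            .map(p -> CompletableFuture.supplyAsync(() -> model.apply(p), executor)
                // Turn this item's failure into a placeholder so allOf() still completes normally
                .exceptionally(ex -> placeholder))
            .toList();
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
            // Every future is complete here, so join() does not block; order follows the input list
            .thenApply(v -> futures.stream().map(CompletableFuture::join).toList());
    }
}
```

Logging the exception inside exceptionally(...) before returning the placeholder keeps the failure visible without discarding the rest of the batch.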
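3.1.4 Error Handling and Retries
In production, AI services can fail temporarily because of network fluctuations, provider rate limiting and similar causes. Combined with Spring Retry, Spring AI provides a robust retry mechanism. Note that Spring AI's model clients already retry transient failures on their own (tunable through spring.ai.retry.* properties such as spring.ai.retry.max-attempts). Any additional application-level retry therefore multiplies the total number of attempts.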
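package com.example.ai.config;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.retry.annotation.EnableRetry;
import org.springframework.retry.backoff.ExponentialBackOffPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResourceAccessException;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
@Configuration
@EnableRetry
public class RetryConfig {
    /**
     * Programmatic RetryTemplate for code that calls retryTemplate.execute(...).
     * Methods annotated with @Retryable use their own annotation attributes, not this bean.
     * Spring AI's auto-configured RetryTemplate is conditional on a missing bean, so this bean
     * will likely also govern Spring AI's internal model-call retries.
     */
    @Bean
    public RetryTemplate retryTemplate() {
        RetryTemplate retryTemplate = new RetryTemplate();
        // Exponential backoff
        ExponentialBackOffPolicy backOffPolicy = new ExponentialBackOffPolicy();
        backOffPolicy.setInitialInterval(1000); // first wait: 1 second
        backOffPolicy.setMultiplier(2.0);       // double the wait each time
        backOffPolicy.setMaxInterval(10000);    // cap each wait at 10 seconds
        retryTemplate.setBackOffPolicy(backOffPolicy);
        // Retry policy: retry only transient, network-related failures
        Map<Class<? extends Throwable>, Boolean> retryableExceptions = new HashMap<>();
        // IOException also covers ConnectException and SocketTimeoutException
        retryableExceptions.put(IOException.class, true);
        retryableExceptions.put(ResourceAccessException.class, true);
        // Spring AI maps retryable HTTP errors (for example 5xx) to TransientAiException
        retryableExceptions.put(TransientAiException.class, true);
        // 3 attempts in total; traverseCauses = true also inspects wrapped causes
        SimpleRetryPolicy retryPolicy = new SimpleRetryPolicy(3, retryableExceptions, true);
        retryTemplate.setRetryPolicy(retryPolicy);
        return retryTemplate;
    }
}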
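To see what these settings mean in wall-clock terms, the helper below computes the wait before each retry for an exponential policy. It takes the same three parameters: initial interval, multiplier and maximum interval. It is a plain-Java illustration (BackoffSchedule is a hypothetical name), not Spring Retry's implementation, and it ignores jitter.

```java
import java.util.ArrayList;
import java.util.List;

/** Computes the waits an exponential backoff policy would produce (no jitter). */
final class BackoffSchedule {

    /**
     * @param maxAttempts total attempts, including the first call
     * @return the wait before each retry, in milliseconds
     */
    static List<Long> delays(int maxAttempts, long initialMs, double multiplier, long maxMs) {
        List<Long> delays = new ArrayList<>();
        double next = initialMs;
        for (int retry = 1; retry < maxAttempts; retry++) {   // n attempts => n - 1 waits
            delays.add((long) Math.min(next, maxMs));
            next *= multiplier;
        }
        return delays;
    }
}
```

With the values above (3 attempts, 1000 ms, ×2, 10 s cap), a call that keeps failing waits 1 s and then 2 s before giving up. With six attempts the schedule becomes 1 s, 2 s, 4 s, 8 s, 10 s, so the cap only takes effect at the fifth wait.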
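3.1.5 Implementing Resilience Patterns
Combining a circuit breaker, rate limiting and timeout control makes AI service calls considerably more robust.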
package com.example.ai.service;
import io.github.resilience4j.circuitbreaker.annotation.CircuitBreaker;
import io.github.resilience4j.ratelimiter.annotation.RateLimiter;
import io.github.resilience4j.timelimiter.annotation.TimeLimiter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
@Slf4j
@Service
public class ResilientAIService {
    @Autowired
    private AIModelService aiModelService;
    @Autowired
    @Qualifier("aiTaskExecutor")
    private Executor aiTaskExecutor;
    /**
     * Resilient AI call
     * - Circuit breaker: fail fast after repeated failures
     * - Rate limiter: cap the call rate
     * - Time limiter: cap the call duration (requires a CompletableFuture return type)
     * The "aiService" instances must be configured under resilience4j.* (not shown in the application.yml above).
     */
    @CircuitBreaker(name = "aiService", fallbackMethod = "fallbackChat")
    @RateLimiter(name = "aiService")
    @TimeLimiter(name = "aiService")
    public CompletableFuture<String> resilientChat(String prompt, Map<String, Object> context) {
        return CompletableFuture.supplyAsync(() -> aiModelService.chat(prompt, context), aiTaskExecutor);
    }
    /**
     * Fallback method: same parameters as the protected method plus a trailing Throwable.
     * AIModelService already switches to its backup model internally, so calling it again
     * here would only repeat the failing path; return a canned response instead.
     */
    public CompletableFuture<String> fallbackChat(
            String prompt, Map<String, Object> context, Throwable throwable) {
        log.error("AI service call failed, returning degraded response", throwable);
        return CompletableFuture.completedFuture(
            "I'm sorry, but I'm currently experiencing technical difficulties. "
                + "Please try again later or contact support if the issue persists.");
    }
}
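@CircuitBreaker hides a small state machine, and the sketch below makes it explicit with a count-based breaker. The breaker stays CLOSED until failureThreshold consecutive failures occur. It then turns OPEN, and calls go straight to the fallback for openDuration. After that it moves to HALF_OPEN and allows a single trial call. SimpleCircuitBreaker is a hypothetical teaching class. Resilience4j itself evaluates failure rates over a sliding window and is configured per instance (for example under resilience4j.circuitbreaker.instances.aiService.*).

```java
import java.time.Duration;
import java.util.function.LongSupplier;
import java.util.function.Supplier;

/** Minimal count-based circuit breaker (illustrative only). */
final class SimpleCircuitBreaker {
    enum State { CLOSED, OPEN, HALF_OPEN }

    private final int failureThreshold;
    private final long openNanos;
    private final LongSupplier nanoClock;
    private State state = State.CLOSED;
    private int consecutiveFailures;
    private long openedAt;

    SimpleCircuitBreaker(int failureThreshold, Duration openDuration, LongSupplier nanoClock) {
        this.failureThreshold = failureThreshold;
        this.openNanos = openDuration.toNanos();
        this.nanoClock = nanoClock;
    }

    synchronized State state() {
        return state;
    }

    // synchronized keeps the sketch simple but serializes calls; a real breaker would not
    synchronized <T> T call(Supplier<T> action, Supplier<T> fallback) {
        if (state == State.OPEN) {
            if (nanoClock.getAsLong() - openedAt < openNanos) {
                return fallback.get();               // fail fast while the circuit is open
            }
            state = State.HALF_OPEN;                 // wait elapsed: allow one trial call
        }
        try {
            T result = action.get();
            state = State.CLOSED;                    // success closes the circuit
            consecutiveFailures = 0;
            return result;
        } catch (RuntimeException e) {
            consecutiveFailures++;
            if (state == State.HALF_OPEN || consecutiveFailures >= failureThreshold) {
                state = State.OPEN;                  // trip (or re-trip) the breaker
                openedAt = nanoClock.getAsLong();
            }
            return fallback.get();
        }
    }
}
```

The key property is visible in the OPEN branch: while the circuit is open the model is not called at all. This is what protects a struggling provider, and your own thread pool, from a flood of doomed requests.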
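3.2 Prompt Engineering in Practice
3.2.1 Building Prompt Templates
Spring AI's PromptTemplate helps us organize and reuse prompts. The builder below combines it with a system message and few-shot examples.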
package com.example.ai.prompt;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Prompt template builder
 * Builds complex prompts from a system message, few-shot examples and template variables.
 */
public class PromptTemplateBuilder {
    private String systemMessage;
    // Few-shot examples, stored as alternating user/assistant messages
    private final List<Message> examples = new ArrayList<>();
    private final Map<String, Object> variables = new HashMap<>();
    /**
     * Sets the system message
     */
    public PromptTemplateBuilder withSystemMessage(String systemMessage) {
        this.systemMessage = systemMessage;
        return this;
    }
    /**
     * Adds a few-shot example. The expected answer becomes an AssistantMessage, so the model
     * sees complete question/answer turns instead of a run of unanswered user messages.
     */
    public PromptTemplateBuilder addExample(String userInput, String expectedAnswer) {
        this.examples.add(new UserMessage(userInput));
        this.examples.add(new AssistantMessage(expectedAnswer));
        return this;
    }
    /**
     * Adds a variable
     */
    public PromptTemplateBuilder addVariable(String key, Object value) {
        this.variables.put(key, value);
        return this;
    }
    /**
     * Adds several variables
     */
    public PromptTemplateBuilder addVariables(Map<String, Object> variables) {
        this.variables.putAll(variables);
        return this;
    }
    /**
     * Builds the prompt
     */
    public Prompt build(String templateContent) {
        List<Message> messages = new ArrayList<>();
        // System message first
        if (systemMessage != null && !systemMessage.isEmpty()) {
            messages.add(new SystemMessage(systemMessage));
        }
        // Then the few-shot examples
        messages.addAll(examples);
        // Finally the rendered template ({name} placeholders filled from the variables)
        PromptTemplate template = new PromptTemplate(templateContent);
        messages.add(new UserMessage(template.render(variables)));
        return new Prompt(messages);
    }
}
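PromptTemplate placeholders use {name} syntax; Spring AI renders them with StringTemplate by default. The sketch below is not Spring AI's implementation. It illustrates only the substitution step in plain Java, and adds a guard that rejects templates with unfilled variables instead of sending a literal "{name}" to the model. MiniTemplate is a hypothetical name.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Plain-Java illustration of {name} placeholder substitution with a missing-variable check. */
final class MiniTemplate {
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{([A-Za-z_][A-Za-z0-9_]*)\\}");

    static String render(String template, Map<String, ?> variables) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (m.find()) {
            String name = m.group(1);
            if (!variables.containsKey(name)) {
                throw new IllegalArgumentException("Missing template variable: " + name);
            }
            // quoteReplacement stops '$' or '\' in the value (common in code snippets) from being interpreted
            m.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(variables.get(name))));
        }
        m.appendTail(out);
        return out.toString();
    }
}
```

The quoteReplacement call matters for prompts such as code review, where user-supplied snippets routinely contain `$`.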
3.2.2 Prompt Template Manager
package com.example.ai.prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Prompt template manager
 * Loads prompt templates and caches them.
 */
@Component
public class PromptTemplateManager {
    @Autowired
    private ResourceLoader resourceLoader;
    private final Map<String, PromptTemplate> templateCache = new ConcurrentHashMap<>();
    /**
     * Loads a template from a classpath resource (cached after the first load)
     */
    public PromptTemplate loadTemplate(String templatePath) {
        return templateCache.computeIfAbsent(templatePath, path -> {
            Resource resource = resourceLoader.getResource("classpath:" + path);
            try {
                // getContentAsString reads the whole resource and closes the stream
                return new PromptTemplate(resource.getContentAsString(StandardCharsets.UTF_8));
            } catch (IOException e) {
                throw new UncheckedIOException("Failed to load template: " + path, e);
            }
        });
    }
    /**
     * Returns the code review prompt template
     */
    public PromptTemplate getCodeReviewTemplate() {
        return loadTemplate("templates/code-review.txt");
    }
    /**
     * Returns the document Q&A prompt template
     */
    public PromptTemplate getDocumentQATemplate() {
        return loadTemplate("templates/document-qa.txt");
    }
    /**
     * Returns the customer service prompt template
     */
    public PromptTemplate getCustomerServiceTemplate() {
        return loadTemplate("templates/customer-service.txt");
    }
}
3.2.3 Dynamic Prompt Generation
package com.example.ai.prompt;
import org.springframework.stereotype.Component;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
/**
 * Dynamic prompt generator
 * Adds context to a prompt based on the user's role and the current scenario.
 */
@Component
public class DynamicPromptGenerator {
    private static final DateTimeFormatter TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
    /**
     * Generates a context-aware prompt
     */
    public String generateContextualPrompt(String basePrompt, String userRole, String scenario) {
        StringBuilder promptBuilder = new StringBuilder();
        // Time context
        promptBuilder.append("Current time: ")
            .append(LocalDateTime.now().format(TIME_FORMAT))
            .append("\n\n");
        // Role context
        if (userRole != null && !userRole.isEmpty()) {
            promptBuilder.append("You are interacting with a user who has the role: ")
                .append(userRole)
                .append("\n");
        }
        // Scenario context
        if (scenario != null && !scenario.isEmpty()) {
            promptBuilder.append("The current scenario is: ")
                .append(scenario)
                .append("\n\n");
        }
        // Base prompt
        promptBuilder.append(basePrompt);
        // Style instructions
        promptBuilder.append("\n\nPlease provide a response that is professional, "
            + "concise, and directly addresses the query above.");
        return promptBuilder.toString();
    }
    /**
     * Generates a prompt with output-format requirements
     */
    public String generateFormattedPrompt(String basePrompt, String outputFormat) {
        return basePrompt + "\n\nPlease format your response as " + outputFormat;
    }
}
3.3 Handling Structured Output
3.3.1 Validating Output Against a Schema
package com.example.ai.output;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.networknt.schema.JsonSchema;
import com.networknt.schema.JsonSchemaFactory;
import com.networknt.schema.SpecVersion;
import com.networknt.schema.ValidationMessage;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Output schema validator
 * Checks whether AI output conforms to the expected JSON Schema (draft 7).
 */
@Component
public class OutputSchemaValidator {
    private final ObjectMapper objectMapper = new ObjectMapper();
    // Create the factory once instead of on every call
    private final JsonSchemaFactory schemaFactory = JsonSchemaFactory.getInstance(SpecVersion.VersionFlag.V7);
    /**
     * Validates JSON output against a schema
     * @param jsonOutput JSON output generated by the AI
     * @param schemaJson JSON Schema definition
     * @return validation result
     */
    public ValidationResult validate(String jsonOutput, String schemaJson) {
        try {
            // Parse the output and the schema
            JsonNode outputNode = objectMapper.readTree(jsonOutput);
            JsonNode schemaNode = objectMapper.readTree(schemaJson);
            // Build the schema validator
            JsonSchema schema = schemaFactory.getSchema(schemaNode);
            // Run the validation
            Set<ValidationMessage> validationMessages = schema.validate(outputNode);
            if (validationMessages.isEmpty()) {
                return new ValidationResult(true, null);
            }
            String errorDetails = validationMessages.stream()
                .map(ValidationMessage::getMessage)
                .collect(Collectors.joining("; "));
            return new ValidationResult(false, errorDetails);
        } catch (IOException e) {
            return new ValidationResult(false, "Invalid JSON format: " + e.getMessage());
        }
    }
    /**
     * Validation result
     */
    public static class ValidationResult {
        private final boolean valid;
        private final String errorDetails;
        public ValidationResult(boolean valid, String errorDetails) {
            this.valid = valid;
            this.errorDetails = errorDetails;
        }
        public boolean isValid() {
            return valid;
        }
        public String getErrorDetails() {
            return errorDetails;
        }
    }
}
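A failed validation does not have to end the request. A common pattern is to send the validation error back to the model and ask for a corrected answer, up to a small limit. The sketch keeps this library-free. The model is a UnaryOperator<String>, and the validator returns null for valid output or an error message otherwise; wrapping OutputSchemaValidator.validate(...) would give such a function. RepairLoop and the exact retry wording are assumptions for illustration.

```java
import java.util.function.Function;
import java.util.function.UnaryOperator;

/** Hypothetical "validate, then ask the model to fix it" loop for structured output. */
final class RepairLoop {

    /**
     * @param model     sends a prompt and returns the raw model output
     * @param validator returns null when the output is valid, otherwise an error description
     */
    static String generateValid(String prompt, UnaryOperator<String> model,
                                Function<String, String> validator, int maxAttempts) {
        String currentPrompt = prompt;
        String lastError = null;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            String output = model.apply(currentPrompt);
            lastError = validator.apply(output);
            if (lastError == null) {
                return output;
            }
            // Feed the validation error back so the next attempt can correct it
            currentPrompt = prompt + "\n\nYour previous answer was invalid: " + lastError
                + "\nReturn only JSON that satisfies the schema.";
        }
        throw new IllegalStateException("No valid output after " + maxAttempts + " attempts: " + lastError);
    }
}
```

Keep maxAttempts small (two or three). Each repair round is a full, billed model call, and a model that fails twice on the same schema rarely succeeds on the tenth try.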
3.3.2 Output Conversion and Mapping
package com.example.ai.output;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.stereotype.Component;
import java.io.IOException;
/**
 * Output converter
 * Converts AI output into Java objects.
 */
@Component
public class OutputConverter {
    private final ObjectMapper objectMapper = new ObjectMapper();
    /**
     * Converts an AI response into an object of the given type
     * @param response    AI response
     * @param targetClass target type
     * @return converted object
     */
    public <T> T convertToObject(ChatResponse response, Class<T> targetClass) {
        String content = response.getResult().getOutput().getText();
        return convertJsonToObject(content, targetClass);
    }
    /**
     * Converts a JSON string into an object of the given type
     * @param jsonContent JSON string
     * @param targetClass target type
     * @return converted object
     */
    public <T> T convertJsonToObject(String jsonContent, Class<T> targetClass) {
        try {
            return objectMapper.readValue(jsonContent, targetClass);
        } catch (IOException e) {
            throw new OutputConversionException(
                "Failed to convert output to " + targetClass.getSimpleName(), e);
        }
    }
}
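Chat models frequently wrap JSON in a Markdown code fence or add a sentence before it, and objectMapper.readValue fails on such text. A small pre-processing step before convertJsonToObject makes the converter more forgiving. The sketch below is a hypothetical helper (JsonExtractor). It matches the fence with the regex quantifier `{3} rather than writing three backticks literally. Spring AI also offers BeanOutputConverter, which generates format instructions from a Java class and so reduces the problem at the prompt side.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Extracts the JSON payload from model output that may be wrapped in a Markdown code fence. */
final class JsonExtractor {
    // (?s) lets '.' match newlines; `{3} matches the three-backtick fence; the language tag is optional
    private static final Pattern FENCED = Pattern.compile("(?s)`{3}[A-Za-z]*\\s*(.*?)\\s*`{3}");

    static String extractJson(String modelOutput) {
        Matcher m = FENCED.matcher(modelOutput);
        // Use the fenced body if there is one, otherwise fall back to the trimmed raw text
        return m.find() ? m.group(1) : modelOutput.trim();
    }
}
```

Usage: `convertJsonToObject(JsonExtractor.extractJson(content), targetClass)`.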
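3.4 Conversation Memory Management
Chat model APIs are stateless: each request is answered only from what that request contains. Conversation memory therefore works by storing the conversation history (user inputs and AI replies) and attaching it to the prompt on the next request.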
sequenceDiagram
participant User
participant App as Spring Boot App
participant ChatMemory as Redis/DB
participant LLM as Large Language Model
User->>+App: "Hi, my name is Zhang San"
App->>LLM: Prompt: "Hi, my name is Zhang San"
LLM->>App: "Hi Zhang San! How can I help you?"
App->>+ChatMemory: Store [User: "Hi, my name is Zhang San", AI: "Hi Zhang San! ..."] for session_id_123
App->>-User: "Hi Zhang San! How can I help you?"
User->>+App: "What is my name?"
App->>+ChatMemory: Retrieve history for session_id_123
ChatMemory->>-App: [User: "Hi, my name is Zhang San", AI: "Hi Zhang San! ..."]
App->>LLM: Prompt: "[History]\nUser: Hi, my name is Zhang San\nAI: Hi Zhang San! ...\n\n[Current]\nUser: What is my name?"
LLM->>App: "Your name is Zhang San."
App->>+ChatMemory: Store [User: "What is my name?", AI: "Your name is Zhang San."] for session_id_123
App->>-User: "Your name is Zhang San."
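The diagram flattens the history into a single text prompt. With Spring AI you would usually pass the history as a list of UserMessage/AssistantMessage objects instead, but the text form makes the idea easy to see. The sketch below reproduces that "[History] ... [Current]" layout and keeps only the most recent turns, the role app.ai.conversation.max-history plays in the configuration. HistoryPromptFormatter and Turn are hypothetical names.

```java
import java.util.List;

/** Formats stored turns plus the new input into the "[History] ... [Current]" layout of the diagram. */
final class HistoryPromptFormatter {

    record Turn(String role, String text) {}

    static String format(List<Turn> history, String currentInput, int maxTurns) {
        StringBuilder sb = new StringBuilder();
        // Keep only the newest maxTurns entries so the prompt does not grow without bound
        List<Turn> recent = history.subList(Math.max(0, history.size() - maxTurns), history.size());
        if (!recent.isEmpty()) {
            sb.append("[History]\n");
            for (Turn t : recent) {
                sb.append(t.role()).append(": ").append(t.text()).append('\n');
            }
            sb.append('\n');
        }
        sb.append("[Current]\nUser: ").append(currentInput);
        return sb.toString();
    }
}
```

For the second turn in the diagram, format(history, "What is my name?", 10) produces exactly the prompt the App sends to the LLM.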
3.4.1 Session State Management
package com.example.ai.session;
import org.springframework.ai.chat.messages.Message;
import org.springframework.data.annotation.Id;
import org.springframework.data.redis.core.RedisHash;
import java.io.Serializable;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Session state
 * Stores the conversation history between a user and the AI.
 * timeToLive (seconds) mirrors app.ai.conversation.ttl: Redis expires the session,
 * and the TTL is applied again on every save.
 */
@RedisHash(value = "ai_sessions", timeToLive = 3600)
public class SessionState implements Serializable {
    @Id
    private String sessionId;
    private String userId;
    private List<MessageEntry> messageHistory = new ArrayList<>();
    private Instant createdAt;
    private Instant lastUpdatedAt;
    private int tokenCount;
    private Map<String, Object> metadata = new HashMap<>();
    // Constructors, getters and setters omitted
    /**
     * Appends a message to the history
     */
    public void addMessage(Message message, int tokens) {
        MessageEntry entry = new MessageEntry(
            message.getMessageType().toString(),
            message.getText(),
            Instant.now(),
            tokens
        );
        messageHistory.add(entry);
        tokenCount += tokens;
        lastUpdatedAt = Instant.now();
    }
    /**
     * Returns the most recent N messages (as a copy, not a view of the internal list)
     */
    public List<MessageEntry> getRecentMessages(int count) {
        int startIndex = Math.max(0, messageHistory.size() - count);
        return new ArrayList<>(messageHistory.subList(startIndex, messageHistory.size()));
    }
    /**
     * Drops the oldest messages beyond the given count
     */
    public void pruneHistory(int maxMessages) {
        if (messageHistory.size() > maxMessages) {
            int removeCount = messageHistory.size() - maxMessages;
            // Keep tokenCount consistent with the messages that remain
            for (MessageEntry removed : messageHistory.subList(0, removeCount)) {
                tokenCount -= removed.tokens;
            }
            messageHistory = new ArrayList<>(messageHistory.subList(removeCount, messageHistory.size()));
        }
    }
    /**
     * Message entry
     */
    public static class MessageEntry implements Serializable {
        private String type;
        private String content;
        private Instant timestamp;
        private int tokens;
        // Constructors, getters and setters omitted
    }
}
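pruneHistory limits the number of messages, but context windows are measured in tokens, and SessionState already records a token count per entry. A token-budget variant keeps the newest messages until the budget is used up. The sketch uses a hypothetical Entry record in place of MessageEntry (TokenBudget is also hypothetical). The token numbers would come from whatever TokenCounter the application uses.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/** Keeps the newest messages whose token counts fit into a budget (hypothetical helper). */
final class TokenBudget {

    record Entry(String content, int tokens) {}

    static List<Entry> fit(List<Entry> history, int maxTokens) {
        Deque<Entry> kept = new ArrayDeque<>();
        int used = 0;
        // Walk backwards from the newest message and stop at the first one that no longer fits
        for (int i = history.size() - 1; i >= 0; i--) {
            Entry e = history.get(i);
            if (used + e.tokens() > maxTokens) {
                break;
            }
            kept.addFirst(e);                 // addFirst restores chronological order
            used += e.tokens();
        }
        return new ArrayList<>(kept);
    }
}
```

Stopping at the first message that does not fit, rather than skipping it and continuing, keeps the retained history contiguous. A gap in the middle of a conversation tends to confuse the model more than a shorter history does.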
3.4.2 Session Management Service
package com.example.ai.session;
import org.springframework.ai.chat.messages.Message;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
/**
 * Session management service
 * Manages conversation sessions between users and the AI.
 */
@Service
public class SessionManager {
    @Autowired
    private SessionRepository sessionRepository;
    @Autowired
    private TokenCounter tokenCounter;
    /**
     * Creates a new session
     */
    public SessionState createSession(String userId) {
        SessionState session = new SessionState();
        session.setSessionId(generateSessionId());
        session.setUserId(userId);
        session.setCreatedAt(Instant.now());
        session.setLastUpdatedAt(Instant.now());
        return sessionRepository.save(session);
    }
    /**
     * Gets a session
     */
    public Optional<SessionState> getSession(String sessionId) {
        return sessionRepository.findById(sessionId);
    }
    /**
     * Adds a message to a session
     */
    public void addMessageToSession(String sessionId, Message message) {
        sessionRepository.findById(sessionId).ifPresent(session -> {
            int tokens = tokenCounter.countTokens(message.getText());
            session.addMessage(message, tokens);
            sessionRepository.save(session);
        });
    }
    /**
     * Gets the most recent messages of a session
     */
    public List<Message> getSessionHistory(String sessionId, int maxMessages) {
        return sessionRepository.findById(sessionId)
            .map(session -> convertToMessages(session.getRecentMessages(maxMessages)))
            .orElse(List.of());
    }
    /**
     * Removes expired sessions.
     * Spring Data Redis repositories cannot derive range queries such as
     * findByLastUpdatedAtBefore, so the filtering happens in memory here. For large
     * volumes, prefer a @RedisHash timeToLive and let Redis expire sessions itself.
     */
    public void cleanupExpiredSessions(int maxAgeHours) {
        Instant cutoffTime = Instant.now().minus(Duration.ofHours(maxAgeHours));
        List<SessionState> expiredSessions = StreamSupport
            .stream(sessionRepository.findAll().spliterator(), false)
            .filter(s -> s.getLastUpdatedAt() != null && s.getLastUpdatedAt().isBefore(cutoffTime))
            .toList();
        sessionRepository.deleteAll(expiredSessions);
    }
    // Helper methods (generateSessionId, convertToMessages) omitted
}
3.5 RAG: Retrieval-Augmented Generation
3.5.1 Vector Store Service
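package com.example.ai.rag;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
/**
 * Vector store service
 * Manages document embedding and retrieval. The VectorStore computes embeddings itself
 * through the auto-configured EmbeddingModel, so no separate embedding client is needed.
 */
@Service
public class VectorStoreService {
    @Autowired
    private VectorStore vectorStore;
    /**
     * Adds documents to the vector store (they are embedded on insert)
     */
    public void addDocuments(List<Document> documents) {
        vectorStore.add(documents);
    }
    /**
     * Retrieves the k documents most similar to the query
     */
    public List<Document> search(String query, int k) {
        return vectorStore.similaritySearch(SearchRequest.builder().query(query).topK(k).build());
    }
    /**
     * Same as above, but drops results whose similarity score (0 to 1) is below the threshold,
     * for example app.ai.rag.similarity-threshold
     */
    public List<Document> search(String query, int k, double similarityThreshold) {
        return vectorStore.similaritySearch(SearchRequest.builder()
            .query(query)
            .topK(k)
            .similarityThreshold(similarityThreshold)
            .build());
    }
    /**
     * Deletes documents by ID
     */
    public void removeDocuments(List<String> ids) {
        vectorStore.delete(ids);
    }
}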
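VectorStore.similaritySearch delegates the work to the store (Chroma here). Conceptually it embeds the query, scores stored vectors against it, applies the optional similarity threshold and returns the top k. The plain-Java sketch below does exactly that, using cosine similarity over an in-memory list. It is an illustration only (InMemoryVectorIndex, Item and Hit are hypothetical). Real stores such as Chroma use approximate indexes like HNSW instead of scanning every vector.

```java
import java.util.Comparator;
import java.util.List;

/** In-memory illustration of what a vector store's similarity search computes. */
final class InMemoryVectorIndex {

    record Item(String id, double[] vector) {}
    record Hit(String id, double score) {}

    /** Cosine similarity; assumes both vectors have the same dimension. */
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return (normA == 0 || normB == 0) ? 0.0 : dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Returns up to k items with similarity >= threshold, best match first (full scan). */
    static List<Hit> search(List<Item> items, double[] query, int k, double threshold) {
        return items.stream()
            .map(it -> new Hit(it.id(), cosine(it.vector(), query)))
            .filter(h -> h.score() >= threshold)
            .sorted(Comparator.comparingDouble(Hit::score).reversed())
            .limit(k)
            .toList();
    }
}
```

The threshold is what keeps weakly related chunks out of the prompt. Without it, topK always returns k documents even when none of them is relevant.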
3.5.2 Document Processing and Chunking
package com.example.ai.rag;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.core.io.FileSystemResource;
import org.springframework.stereotype.Component;
import java.io.File;
import java.util.List;
/**
 * Document processor
 * Reads, splits and prepares documents.
 */
@Component
public class DocumentProcessor {
    /**
     * Processes a PDF document
     */
    public List<Document> processPdfDocument(File pdfFile, String documentId) {
        // Create the PDF reader (one Document per page); it takes a Resource, not a File
        PagePdfDocumentReader reader = new PagePdfDocumentReader(new FileSystemResource(pdfFile));
        // Read the document
        List<Document> documents = reader.get();
        // Attach metadata to every page
        documents.forEach(doc -> {
            doc.getMetadata().put("source", pdfFile.getName());
            doc.getMetadata().put("documentId", documentId);
        });
        // Create the splitter: chunk size 512 tokens, min chunk size 350 characters,
        // min chunk length to embed 5, at most 10000 chunks, keep separators.
        // TokenTextSplitter is configured through its constructor; it has no overlap option.
        TokenTextSplitter splitter = new TokenTextSplitter(512, 350, 5, 10000, true);
        // Split the documents
        return splitter.apply(documents);
    }
}
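TokenTextSplitter cuts on token counts and does not take an overlap parameter, while the configuration above declares app.ai.rag.chunk-overlap: 200. Overlap repeats the end of one chunk at the start of the next, so a sentence that straddles a boundary can still be retrieved from either side. Below is a minimal character-based sketch. OverlapChunker is hypothetical, and a production splitter would count tokens and prefer sentence boundaries.

```java
import java.util.ArrayList;
import java.util.List;

/** Character-based sliding-window chunker illustrating chunk-size / chunk-overlap. */
final class OverlapChunker {

    static List<String> chunk(String text, int chunkSize, int overlap) {
        if (chunkSize <= 0 || overlap < 0 || overlap >= chunkSize) {
            throw new IllegalArgumentException("require chunkSize > 0 and 0 <= overlap < chunkSize");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap;        // each window starts this many characters after the last
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break;                         // this window reached the end; no redundant tail chunk
            }
        }
        return chunks;
    }
}
```

With chunk-size 1000 and chunk-overlap 200 the window advances 800 characters at a time, so the index grows by roughly 25% compared with no overlap. That cost is worth weighing against the retrieval benefit.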
3.5.3 RAG Query Augmentation
package com.example.ai.rag;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.ArrayList;
import java.util.List;
/**
* RAG service
* Implements retrieval-augmented generation: retrieve relevant chunks, then answer from them
*/
@Service
public class RAGService {
@Autowired
private VectorStoreService vectorStoreService;
@Autowired
private ChatClient chatClient;
/**
* Run a RAG query
*/
public String query(String userQuery) {
// Retrieve the 3 most relevant chunks
List<Document> relevantDocs = vectorStoreService.search(userQuery, 3);
// Turn the chunks into a context block
String context = buildContext(relevantDocs);
// Build the prompt: instructions and context in the system message, the question in the user message
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage(
"You are a helpful assistant. Use the following context to answer the user's question. " +
"If the answer cannot be found in the context, say 'I don't have enough information to answer that.'\n\n" +
"Context:\n" + context
));
messages.add(new UserMessage(userQuery));
// Call the model
Prompt prompt = new Prompt(messages);
return chatClient.call(prompt).getResult().getOutput().getContent();
}
/**
* Build the context string, labeling each chunk with its source file
*/
private String buildContext(List<Document> documents) {
StringBuilder contextBuilder = new StringBuilder();
for (Document doc : documents) {
contextBuilder.append("--- Document: ")
.append(doc.getMetadata().get("source"))
.append(" ---\n")
.append(doc.getContent())
.append("\n\n");
}
return contextBuilder.toString();
}
}
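`buildContext` appends every retrieved chunk no matter how long it is. A simple guard, sketched below as a hypothetical `BudgetedContextBuilder`, stops adding chunks once a budget would be exceeded. This keeps the prompt inside the model's context window. For simplicity the budget is counted in characters, which is only a rough proxy for tokens.

```java
import java.util.List;

/** Builds a RAG context block without exceeding a character budget (a rough proxy for tokens). */
final class BudgetedContextBuilder {

    /** A retrieved chunk: where it came from and its text. */
    record Chunk(String source, String text) {}

    static String build(List<Chunk> chunks, int maxChars) {
        StringBuilder sb = new StringBuilder();
        for (Chunk chunk : chunks) {
            String entry = "--- Document: " + chunk.source() + " ---\n" + chunk.text() + "\n\n";
            if (sb.length() + entry.length() > maxChars) {
                // Results arrive best-first, so stopping here drops the least relevant chunks
                break;
            }
            sb.append(entry);
        }
        return sb.toString();
    }
}
```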
3.5.4 Incremental Knowledge Base Updates
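package com.example.ai.rag;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import java.io.File;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Knowledge base updater
* Periodically scans the document directory and indexes new or modified PDFs.
* Requires @EnableScheduling on a configuration class. State is kept in memory, so after a
* restart every file is indexed again; a persistent vector store then needs its old chunks cleared.
*/
@Service
public class KnowledgeBaseUpdater {
private static final Logger log = LoggerFactory.getLogger(KnowledgeBaseUpdater.class);
@Autowired
private VectorStoreService vectorStoreService;
@Autowired
private DocumentProcessor documentProcessor;
// Absolute path -> lastModified of the version currently indexed
private final Map<String, Long> indexedVersions = new ConcurrentHashMap<>();
// Absolute path -> IDs of the chunks stored for that file, so a modified file's old chunks can be removed
private final Map<String, List<String>> chunkIdsByFile = new ConcurrentHashMap<>();
/**
* Periodically scan the document directory and update the knowledge base
* (default interval 3,600,000 ms = 1 hour)
*/
@Scheduled(fixedDelayString = "${app.ai.rag.update-interval:3600000}")
public void updateKnowledgeBase() {
File directory = Paths.get("docs").toFile();
if (directory.isDirectory()) {
scanDirectory(directory);
}
}
/**
* Recursively scan a directory for documents
*/
private void scanDirectory(File directory) {
File[] files = directory.listFiles();
if (files == null) return;
for (File file : files) {
if (file.isDirectory()) {
scanDirectory(file);
} else if (isPdfFile(file) && !isUpToDate(file)) {
processFile(file);
}
}
}
/**
* Chunk the file and add it to the vector store, replacing any previously indexed version
*/
private void processFile(File file) {
String path = file.getAbsolutePath();
long lastModified = file.lastModified();
try {
List<Document> chunks = documentProcessor.processPdfDocument(file, path);
vectorStoreService.addDocuments(chunks);
// Remove the previous version's chunks once the new ones are stored, so stale text stops matching
List<String> oldIds = chunkIdsByFile.put(path, chunks.stream().map(Document::getId).toList());
if (oldIds != null && !oldIds.isEmpty()) {
vectorStoreService.removeDocuments(oldIds);
}
// Mark as indexed only after success, so a failed file is retried on the next scan
indexedVersions.put(path, lastModified);
} catch (Exception e) {
// Log the error and continue with the remaining files
log.warn("Failed to index {}", path, e);
}
}
private boolean isPdfFile(File file) {
return file.getName().toLowerCase().endsWith(".pdf");
}
private boolean isUpToDate(File file) {
Long indexed = indexedVersions.get(file.getAbsolutePath());
return indexed != null && indexed == file.lastModified();
}
}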
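At its core, an incremental update decides which files are new, modified, or deleted since the last scan. The sketch below isolates that decision as a pure function over `path -> lastModified` maps, which makes it easy to unit-test. Handling deleted files goes beyond what the updater above does, and `ChangeDetector` is a hypothetical name.

```java
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/** Compares the previous scan with the current one (path -> lastModified) and classifies each path. */
final class ChangeDetector {

    record ChangeSet(Set<String> added, Set<String> modified, Set<String> deleted) {}

    static ChangeSet diff(Map<String, Long> previous, Map<String, Long> current) {
        Set<String> added = new TreeSet<>();
        Set<String> modified = new TreeSet<>();
        for (Map.Entry<String, Long> e : current.entrySet()) {
            Long before = previous.get(e.getKey());
            if (before == null) {
                added.add(e.getKey());                    // never indexed: index it
            } else if (!before.equals(e.getValue())) {
                modified.add(e.getKey());                 // changed: re-index and drop old chunks
            }
        }
        Set<String> deleted = new TreeSet<>(previous.keySet());
        deleted.removeAll(current.keySet());              // indexed before but gone now: drop its chunks
        return new ChangeSet(added, modified, deleted);
    }
}
```

For a deleted path, the updater would need the IDs of the chunks stored for that file so that it could pass them to `removeDocuments`.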
3.6 Tool Calling Integration
3.6.1 Tool Definition and Registration
package com.example.ai.tool;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Tool annotation
* Marks a method that the AI model may call; RUNTIME retention makes it visible to reflection
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface AITool {
/**
* Tool name
*/
String name();
/**
* Tool description (shown to the model, so it should say what the tool does and what its parameters mean)
*/
String description();
/**
* Tool category
*/
ToolCategory category() default ToolCategory.GENERAL;
/**
* Tool categories
*/
enum ToolCategory {
GENERAL,
DATA_ANALYSIS,
EXTERNAL_API,
SYSTEM_OPERATION
}
}
package com.example.ai.tool;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Tool registry
* Manages and executes the tools the model may call. This prompt-based protocol is for illustration;
* Spring AI also supports native function calling through the model provider's API.
* ToolDefinition and ToolExecutionException are assumed to be defined elsewhere in the project.
*/
@Service
public class ToolRegistry {
private final Map<String, ToolDefinition> tools = new ConcurrentHashMap<>();
private final ObjectMapper objectMapper = new ObjectMapper();
@Autowired
private OpenAiChatClient chatClient;
/**
* Register a tool
*/
public void registerTool(Object instance, Method method, AITool annotation) {
ToolDefinition tool = new ToolDefinition(
annotation.name(),
annotation.description(),
annotation.category(),
instance,
method
);
tools.put(annotation.name(), tool);
}
/**
* Execute a tool call
*/
public Object executeTool(String toolName, Map<String, Object> parameters) {
ToolDefinition tool = tools.get(toolName);
if (tool == null) {
throw new IllegalArgumentException("Unknown tool: " + toolName);
}
try {
return tool.execute(parameters);
} catch (Exception e) {
throw new ToolExecutionException("Failed to execute tool: " + toolName, e);
}
}
/**
* List the tools in the OpenAI function-calling format
*/
public List<Map<String, Object>> getToolDefinitions() {
List<Map<String, Object>> definitions = new ArrayList<>();
for (ToolDefinition tool : tools.values()) {
Map<String, Object> definition = new HashMap<>();
definition.put("type", "function");
Map<String, Object> function = new HashMap<>();
function.put("name", tool.getName());
function.put("description", tool.getDescription());
// A JSON Schema describing the parameters would go under "parameters" here
definition.put("function", function);
definitions.add(definition);
}
return definitions;
}
/**
* Answer a query, letting the model request one tool call
*/
public String getToolEnhancedResponse(String userQuery) {
// System prompt that lists the available tools
StringBuilder systemPrompt = new StringBuilder();
systemPrompt.append("You have access to the following tools:\n\n");
for (ToolDefinition tool : tools.values()) {
systemPrompt.append("- ").append(tool.getName())
.append(": ").append(tool.getDescription())
.append("\n");
}
systemPrompt.append("\nWhen you need to use a tool, respond with JSON in this format:\n")
.append("{\"tool\": \"tool_name\", \"parameters\": {\"param1\": \"value1\"}}\n\n")
.append("If you don't need to use a tool, just respond normally.");
// Build the prompt
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage(systemPrompt.toString()));
messages.add(new UserMessage(userQuery));
// Call the model
String response = chatClient.call(new Prompt(messages)).getResult().getOutput().getContent();
// Check whether the model asked for a tool
String trimmed = response.trim();
if (!trimmed.startsWith("{") || !trimmed.contains("\"tool\"")) {
return response;
}
Map<String, Object> call;
try {
call = objectMapper.readValue(trimmed, new TypeReference<Map<String, Object>>() {});
} catch (JsonProcessingException e) {
// Not valid JSON: treat it as a normal answer
return response;
}
@SuppressWarnings("unchecked")
Map<String, Object> parameters = (Map<String, Object>) call.getOrDefault("parameters", Map.of());
Object result = executeTool(String.valueOf(call.get("tool")), parameters);
// Send the tool result back so the model can phrase the final answer
messages.add(new AssistantMessage(response));
messages.add(new UserMessage("Tool result: " + result + "\nUse this result to answer my original question."));
return chatClient.call(new Prompt(messages)).getResult().getOutput().getContent();
}
}
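`getToolEnhancedResponse` handles a single tool request, but models often need several steps: call a tool, read the result, then call another tool. The sketch below shows a bounded loop over such a protocol in plain Java. It uses a scripted stand-in for the model, so it runs without an API key. The `Reply`, `ToolCall`, and `Answer` types and the `maxSteps` guard are hypothetical and are not Spring AI APIs.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** A bounded tool-calling loop: ask the model, run requested tools, feed results back, repeat. */
final class ToolLoop {

    /** What the model returns on each turn: either a tool request or a final answer. */
    sealed interface Reply {}
    record ToolCall(String tool, Map<String, Object> parameters) implements Reply {}
    record Answer(String text) implements Reply {}

    static String run(Function<List<String>, Reply> model,
                      Map<String, Function<Map<String, Object>, Object>> tools,
                      String question,
                      int maxSteps) {
        List<String> transcript = new ArrayList<>(List.of("user: " + question));
        for (int step = 0; step < maxSteps; step++) {
            Reply reply = model.apply(transcript);
            if (reply instanceof Answer answer) {
                return answer.text();
            }
            ToolCall call = (ToolCall) reply;
            Function<Map<String, Object>, Object> tool = tools.get(call.tool());
            Object result = (tool == null)
                    ? "error: unknown tool " + call.tool()
                    : tool.apply(call.parameters());
            // The tool result becomes context for the next model turn
            transcript.add("tool " + call.tool() + ": " + result);
        }
        // Guard against a model that keeps requesting tools forever
        throw new IllegalStateException("No final answer after " + maxSteps + " steps");
    }
}
```

Reporting an unknown tool back to the model, instead of throwing, gives the model a chance to correct its request on the next turn.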
3.6.2 Example Tool Implementations
package com.example.ai.tool.impl;
import com.example.ai.tool.AITool;
import org.springframework.stereotype.Component;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;
/**
* Utility tools
* General-purpose helper functions the model can call
*/
@Component
public class UtilityTools {
/**
* Get the current time
*/
@AITool(
name = "getCurrentTime",
description = "Get the current date and time; optional 'format' is a DateTimeFormatter pattern (default yyyy-MM-dd HH:mm:ss)"
)
public Map<String, String> getCurrentTime(String format) {
LocalDateTime now = LocalDateTime.now();
Map<String, String> result = new HashMap<>();
if (format == null || format.isEmpty()) {
format = "yyyy-MM-dd HH:mm:ss";
}
result.put("formatted", now.format(DateTimeFormatter.ofPattern(format)));
result.put("iso", now.toString());
return result;
}
/**
* Calculator
*/
@AITool(
name = "calculator",
description = "Perform basic arithmetic; 'operation' is one of add, subtract, multiply, divide"
)
public double calculate(String operation, double a, double b) {
switch (operation.toLowerCase()) {
case "add":
return a + b;
case "subtract":
return a - b;
case "multiply":
return a * b;
case "divide":
if (b == 0) throw new IllegalArgumentException("Cannot divide by zero");
return a / b;
default:
throw new IllegalArgumentException("Unknown operation: " + operation);
}
}
}
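The code that registers `@AITool` methods with the registry is not shown. One possible approach is to scan each bean's public methods by reflection and invoke the matching method by tool name, as sketched below in plain Java. `@DemoTool` stands in for `@AITool` so the example compiles on its own. `ReflectiveToolScanner` and `SampleTools` are hypothetical. In a Spring application, such a scan could run from a `BeanPostProcessor`.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/** Stand-in for @AITool; RUNTIME retention is required for reflection to see it. */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@interface DemoTool {
    String name();
}

/** Example tool holder, analogous to UtilityTools. */
final class SampleTools {
    @DemoTool(name = "add")
    public int add(int a, int b) {
        return a + b;
    }

    @DemoTool(name = "upper")
    public String upper(String s) {
        return s.toUpperCase();
    }
}

/** Finds @DemoTool methods on an object and invokes them by tool name. */
final class ReflectiveToolScanner {
    private final Map<String, Object> targets = new TreeMap<>();
    private final Map<String, Method> methods = new TreeMap<>();

    void register(Object bean) {
        for (Method m : bean.getClass().getMethods()) {       // public methods only
            DemoTool tool = m.getAnnotation(DemoTool.class);
            if (tool != null) {
                targets.put(tool.name(), bean);
                methods.put(tool.name(), m);
            }
        }
    }

    /** Invokes the tool with positional arguments (reflection unboxes them for primitive parameters). */
    Object invoke(String name, Object... args) {
        Method m = methods.get(name);
        if (m == null) {
            throw new IllegalArgumentException("Unknown tool: " + name);
        }
        try {
            return m.invoke(targets.get(name), args);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new IllegalStateException("Tool " + name + " failed", e);
        }
    }

    Set<String> toolNames() {
        return methods.keySet();
    }
}
```

Arguments are passed positionally here. Mapping a JSON object of named parameters onto method parameters requires the parameter names, and reflection only reports those when the code is compiled with `-parameters`.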
3.7 Multi-Model Chain Processing (MCP)
3.7.1 Defining the Model Chain
package com.example.ai.mcp;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
/**
* Model chain
* Runs a sequence of model calls; each node's output becomes the next node's input
*/
public class ModelChain {
private final List<ChainNode> nodes = new ArrayList<>();
/**
* Add a node: a name, the model to call, and a function that builds the node's prompt from the previous output
*/
public ModelChain addNode(String name, ChatClient client, Function<String, String> promptTransformer) {
nodes.add(new ChainNode(name, client, promptTransformer));
return this;
}
/**
* Execute the chain
*/
public ChainResult execute(String initialInput) {
String currentInput = initialInput;
List<NodeResult> results = new ArrayList<>();
for (ChainNode node : nodes) {
// Build this node's prompt from the previous output
String transformedPrompt = node.getPromptTransformer().apply(currentInput);
// Call the model
Prompt prompt = new Prompt(List.of(new UserMessage(transformedPrompt)));
String output = node.getClient().call(prompt).getResult().getOutput().getContent();
// Record this node's prompt and output
results.add(new NodeResult(node.getName(), transformedPrompt, output));
// Feed the output into the next node
currentInput = output;
}
return new ChainResult(results);
}
/**
* Chain node
*/
private static class ChainNode {
private final String name;
private final ChatClient client;
private final Function<String, String> promptTransformer;
public ChainNode(String name, ChatClient client, Function<String, String> promptTransformer) {
this.name = name;
this.client = client;
this.promptTransformer = promptTransformer;
}
public String getName() {
return name;
}
public ChatClient getClient() {
return client;
}
public Function<String, String> getPromptTransformer() {
return promptTransformer;
}
}
/**
* Result of a single node
*/
public static class NodeResult {
private final String nodeName;
private final String input;
private final String output;
public NodeResult(String nodeName, String input, String output) {
this.nodeName = nodeName;
this.input = input;
this.output = output;
}
public String getNodeName() {
return nodeName;
}
public String getInput() {
return input;
}
public String getOutput() {
return output;
}
}
/**
* Result of executing the whole chain
*/
public static class ChainResult {
private final List<NodeResult> nodeResults;
public ChainResult(List<NodeResult> nodeResults) {
this.nodeResults = nodeResults;
}
public String getFinalOutput() {
if (nodeResults.isEmpty()) {
return "";
}
return nodeResults.get(nodeResults.size() - 1).getOutput();
}
public List<NodeResult> getAllResults() {
return nodeResults;
}
}
}
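`ModelChain` only needs something that turns a prompt into an output, so its data flow can be tested without calling a model. The sketch below is a stripped-down, hypothetical variant (`FunctionChain`) in which a `UnaryOperator<String>` stands in for each `ChatClient` call.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Minimal chain: each step turns the previous output into a prompt, then "calls" its model. */
final class FunctionChain {

    record Step(String name, UnaryOperator<String> promptTransformer, UnaryOperator<String> model) {}
    record StepResult(String name, String prompt, String output) {}

    private final List<Step> steps = new ArrayList<>();

    FunctionChain add(String name, UnaryOperator<String> promptTransformer, UnaryOperator<String> model) {
        steps.add(new Step(name, promptTransformer, model));
        return this;
    }

    /** Runs every step in order; each step's output is the next step's input. */
    List<StepResult> execute(String input) {
        List<StepResult> results = new ArrayList<>();
        String current = input;
        for (Step step : steps) {
            String prompt = step.promptTransformer().apply(current);
            String output = step.model().apply(prompt);
            results.add(new StepResult(step.name(), prompt, output));
            current = output;
        }
        return results;
    }
}
```

Recording the prompt as well as the output of each step makes it easy to see which stage of a chain degraded the result.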
3.7.2 Multi-Model Collaboration Example
package com.example.ai.mcp;
import org.springframework.ai.chat.ChatClient;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
/**
* Content generation service
* Produces higher-quality content by chaining several models.
* The three qualified ChatClient beans must be defined in a configuration class.
*/
@Service
public class ContentGenerationService {
@Autowired
@Qualifier("gpt4ChatClient")
private ChatClient ideationModel;
@Autowired
@Qualifier("gpt35ChatClient")
private ChatClient draftingModel;
@Autowired
@Qualifier("anthropicChatClient")
private ChatClient editingModel;
/**
* Generate an article
*/
public String generateArticle(String topic, String targetAudience, int wordCount) {
ModelChain chain = new ModelChain()
// Ideation: a stronger model produces the outline (this node builds its prompt from the topic, not its input)
.addNode("ideation", ideationModel, input ->
"Generate a detailed outline for an article about '" + topic + "' " +
"targeting " + targetAudience + ". " +
"Include main sections and key points for each section."
)
// Drafting: a cheaper model writes the first draft from the outline
.addNode("drafting", draftingModel, outline ->
"Using the following outline, write a " + wordCount + " word article " +
"about '" + topic + "' for " + targetAudience + ". " +
"Make it engaging and informative.\n\nOutline:\n" + outline
)
// Editing: a different model polishes the draft
.addNode("editing", editingModel, draft ->
"Edit and improve the following article draft. " +
"Fix any grammar or style issues, improve flow and clarity, " +
"and ensure it's engaging for " + targetAudience + ".\n\n" + draft
);
return chain.execute(topic).getFinalOutput();
}
}
4. Practical Examples
4.1 Intelligent Customer Service System
package com.example.ai.cases.customerservice;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.ArrayList;
import java.util.List;
/**
* Customer support service
* (SessionManager and KnowledgeBaseService are project services assumed to exist elsewhere)
*/
@Service
public class CustomerSupportService {
@Autowired
private ChatClient chatClient;
@Autowired
private SessionManager sessionManager;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
/**
* Handle a customer query
*/
public String handleCustomerQuery(String sessionId, String userQuery) {
List<Message> messages = new ArrayList<>();
// Add the system prompt on every turn, not only on the first one
messages.add(new SystemMessage(
"You are a helpful customer support assistant for our product. " +
"Be polite, concise, and helpful. If you don't know the answer, " +
"suggest escalating to a human agent."
));
// Copy earlier turns instead of adding to the list SessionManager returns, so this turn's
// system and knowledge-base messages do not leak into the stored history.
// Assumes getSessionHistory returns only previous user/assistant messages.
messages.addAll(sessionManager.getSessionHistory(sessionId));
// Retrieve relevant knowledge base content
String relevantInfo = knowledgeBaseService.retrieveRelevantInformation(userQuery);
// Add it as context for this turn only
if (relevantInfo != null && !relevantInfo.isEmpty()) {
messages.add(new SystemMessage(
"Here is some relevant information that might help with the response:\n" +
relevantInfo
));
}
// Add the user query
messages.add(new UserMessage(userQuery));
// Call the model
Prompt prompt = new Prompt(messages);
String response = chatClient.call(prompt).getResult().getOutput().getContent();
// Store only the query and the answer in the session history
sessionManager.updateSessionHistory(sessionId, userQuery, response);
return response;
}
}
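If session history grows on every turn, it eventually overflows the context window and raises the cost of each request. A common mitigation is a sliding window: keep only the most recent N user/assistant turns and always put the system prompt first. The sketch below shows this in plain Java. `HistoryWindow` is a hypothetical helper (the `SessionManager` implementation is not shown), and it is not thread-safe.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/** Keeps the last maxTurns user/assistant exchanges of one conversation. */
final class HistoryWindow {

    record Msg(String role, String content) {}

    private record Turn(String user, String assistant) {}

    private final int maxTurns;
    private final Deque<Turn> turns = new ArrayDeque<>();

    HistoryWindow(int maxTurns) {
        if (maxTurns < 0) {
            throw new IllegalArgumentException("maxTurns must be >= 0");
        }
        this.maxTurns = maxTurns;
    }

    void addTurn(String userText, String assistantText) {
        turns.addLast(new Turn(userText, assistantText));
        while (turns.size() > maxTurns) {
            turns.removeFirst();                         // drop the oldest exchange
        }
    }

    /** System prompt first, then the retained turns in order, then the new user message. */
    List<Msg> buildPrompt(String systemPrompt, String newUserText) {
        List<Msg> messages = new ArrayList<>();
        messages.add(new Msg("system", systemPrompt));
        for (Turn turn : turns) {
            messages.add(new Msg("user", turn.user()));
            messages.add(new Msg("assistant", turn.assistant()));
        }
        messages.add(new Msg("user", newUserText));
        return messages;
    }
}
```

A fixed turn count is the simplest policy. Counting tokens, or summarizing the dropped turns into a single message, keeps more context at the price of extra work.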
4.2 Intelligent Document Analysis
package com.example.ai.cases.docanalysis;
import com.example.ai.rag.DocumentProcessor;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import java.io.File;
import java.util.List;
/**
* Document analysis service
*/
@Service
public class DocumentAnalysisService {
// Ignore any extra fields the model adds to an insight
private final ObjectMapper objectMapper = new ObjectMapper()
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
@Autowired
private ChatClient chatClient;
@Autowired
private DocumentProcessor documentProcessor;
/**
* Analyze a document and generate a summary
*/
public String generateDocumentSummary(String documentPath) {
String documentContent = readDocument(documentPath);
List<Message> messages = List.of(
new SystemMessage(
"You are a document analysis assistant. Your task is to analyze the provided document " +
"and generate a comprehensive summary that captures the main points, key findings, " +
"and important details. The summary should be well-structured and easy to understand."
),
new UserMessage(
"Please analyze the following document and provide a summary:\n\n" + documentContent
)
);
return chatClient.call(new Prompt(messages)).getResult().getOutput().getContent();
}
/**
* Extract key insights from a document
*/
public List<KeyInsight> extractKeyInsights(String documentPath) {
String documentContent = readDocument(documentPath);
List<Message> messages = List.of(
new SystemMessage(
"Extract key insights from the document. Respond with only a JSON array of objects, " +
"each with the string fields 'category', 'title', and 'description'. " +
"Focus on the most important information."
),
new UserMessage(
"Please extract key insights from the following document:\n\n" + documentContent
)
);
String response = chatClient.call(new Prompt(messages)).getResult().getOutput().getContent();
// Models sometimes wrap JSON in extra text or Markdown fences; keep only the outermost array
int start = response.indexOf('[');
int end = response.lastIndexOf(']');
if (start < 0 || end <= start) {
throw new IllegalStateException("Model response contains no JSON array");
}
try {
return objectMapper.readValue(response.substring(start, end + 1), new TypeReference<List<KeyInsight>>() {});
} catch (JsonProcessingException e) {
throw new IllegalStateException("Model returned malformed insight JSON", e);
}
}
/**
* Read a PDF with DocumentProcessor and join its chunks into one string.
* Very long documents can exceed the model's context window.
*/
private String readDocument(String documentPath) {
List<Document> chunks = documentProcessor.processPdfDocument(new File(documentPath), documentPath);
StringBuilder contentBuilder = new StringBuilder();
for (Document chunk : chunks) {
contentBuilder.append(chunk.getContent()).append("\n\n");
}
return contentBuilder.toString();
}
/**
* Key insight; as a record, Jackson can bind the JSON fields directly
*/
public record KeyInsight(String category, String title, String description) {}
}
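Joining every chunk into one prompt, as both methods above do, fails once a document outgrows the model's context window. A widely used workaround is map-reduce summarization. The map step summarizes batches of chunks separately, and the reduce step summarizes the partial summaries. The sketch below shows the control flow with a `UnaryOperator<String>` standing in for the model call. The character-based batching, the prompt wording, and the class name `MapReduceSummarizer` are assumptions.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Map-reduce summarization over chunks, with the model abstracted as a String -> String function. */
final class MapReduceSummarizer {
    private static final String SEP = "\n\n";

    /** Groups chunks into batches of at most maxChars characters (separators included);
     *  a single chunk longer than maxChars forms its own batch. */
    static List<String> batch(List<String> chunks, int maxChars) {
        List<String> batches = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        for (String chunk : chunks) {
            if (current.length() > 0 && current.length() + chunk.length() + SEP.length() > maxChars) {
                batches.add(current.toString());
                current.setLength(0);
            }
            current.append(chunk).append(SEP);
        }
        if (current.length() > 0) {
            batches.add(current.toString());
        }
        return batches;
    }

    static String summarize(List<String> chunks, int maxChars, UnaryOperator<String> model) {
        List<String> partials = new ArrayList<>();
        for (String b : batch(chunks, maxChars)) {
            partials.add(model.apply("Summarize this part of a document:" + SEP + b));   // map step
        }
        if (partials.size() == 1) {
            return partials.get(0);                        // small document: no reduce step needed
        }
        return model.apply("Combine these partial summaries into one summary:" + SEP
                + String.join(SEP, partials));             // reduce step
    }
}
```

If the partial summaries are themselves too long for one call, the reduce step can be applied again to batches of summaries.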
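5. Performance Optimization and Best Practices
5.1 Performance Optimization Strategies
In production, the performance of AI calls largely determines both latency and cost. Key strategies include:
- Caching: for repeated identical queries, a cache skips the model call entirely and sharply reduces response time. Spring's @Cacheable only matches exact keys; serving similar (not identical) queries requires a semantic cache that compares query embeddings.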
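package com.example.ai.optimization;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.stereotype.Service;
@Service
public class CachedAIService {
// Requires @EnableCaching on a configuration class.
// Without an explicit key, Spring's default key generator combines both arguments into a SimpleKey,
// so no string concatenation of prompt and temperature is needed.
@Cacheable("ai-responses")
public String getCachedResponse(String prompt, double temperature) {
// The real model call goes here; caching is most useful at low temperature, where output is near-deterministic
return "AI response";
}
}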
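- Request batching: combine multiple requests into one (for example, embedding many texts in a single call) to reduce network round trips.
- Model selection: match the model to the task's complexity, using lightweight models for simple tasks.
- Prompt optimization: keep prompts lean to reduce token count, which lowers both cost and latency.
- Connection pooling: reuse HTTP connections to avoid the overhead of establishing new ones.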
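When a shared cache provider is not needed, an in-process LRU cache can be built on `LinkedHashMap` in access order. Its `removeEldestEntry` hook evicts the least recently used entry once capacity is exceeded. The sketch below keys entries on the prompt (with whitespace normalized) plus the temperature. `ResponseLruCache` and the normalization rule are assumptions, and normalization only catches trivially different prompts, not semantically similar ones.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;

/** Small LRU cache for model responses, keyed by normalized prompt and temperature. Not thread-safe. */
final class ResponseLruCache {

    record Key(String prompt, double temperature) {}

    private final Map<Key, String> entries;
    private int misses = 0;

    ResponseLruCache(int capacity) {
        // accessOrder = true: iteration runs from least to most recently accessed
        this.entries = new LinkedHashMap<Key, String>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<Key, String> eldest) {
                return size() > capacity;
            }
        };
    }

    /** Returns the cached response, or calls the loader (the model) and caches its result. */
    String get(String prompt, double temperature, Supplier<String> loader) {
        Key key = new Key(prompt.trim().replaceAll("\\s+", " "), temperature);
        String cached = entries.get(key);            // a hit also marks the entry as recently used
        if (cached != null) {
            return cached;
        }
        misses++;
        String response = loader.get();
        entries.put(key, response);
        return response;
    }

    int misses() {
        return misses;
    }

    int size() {
        return entries.size();
    }
}
```

For multi-instance deployments, the same idea maps onto a shared cache behind `@Cacheable`, so that every instance benefits from every hit.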
5.2 Monitoring and Observability
package com.example.ai.monitoring;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import org.springframework.stereotype.Service;
import java.time.Duration;
/**
* AI monitoring service
* Publishes request counts, error counts, and latency to Micrometer
*/
@Service
public class AIMonitoringService {
private final MeterRegistry meterRegistry;
private final Counter totalRequestsCounter;
private final Counter successRequestsCounter;
private final Counter failedRequestsCounter;
private final Timer responseTimeTimer;
// Constructor injection makes the registry available immediately, so no @PostConstruct is needed.
// (Spring Boot 3, which Spring AI requires, uses jakarta.annotation; javax.annotation.PostConstruct is not provided by default.)
public AIMonitoringService(MeterRegistry meterRegistry) {
this.meterRegistry = meterRegistry;
this.totalRequestsCounter = Counter.builder("ai.requests.total")
.description("Total number of AI requests")
.register(meterRegistry);
this.successRequestsCounter = Counter.builder("ai.requests.success")
.description("Number of successful AI requests")
.register(meterRegistry);
this.failedRequestsCounter = Counter.builder("ai.requests.failed")
.description("Number of failed AI requests")
.register(meterRegistry);
this.responseTimeTimer = Timer.builder("ai.response.time")
.description("AI response time")
.register(meterRegistry);
}
/**
* Record the start of a request
*/
public void recordRequestStart() {
totalRequestsCounter.increment();
}
/**
* Record a successful request and its latency
*/
public void recordRequestSuccess(long durationMs) {
successRequestsCounter.increment();
responseTimeTimer.record(Duration.ofMillis(durationMs));
}
/**
* Record a failed request, also counting it per error type
*/
public void recordRequestFailure(String errorType) {
failedRequestsCounter.increment();
// MeterRegistry.counter returns the existing counter for the same name and tags, so no duplicates are created
meterRegistry.counter("ai.requests.errors", "error.type", errorType).increment();
}
}
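`AIMonitoringService` exposes recording methods, but each caller still has to time the request and report the outcome consistently. The sketch below keeps that bookkeeping in one wrapper. `Metrics` is a hypothetical interface that mirrors the service's three methods, so the Micrometer-backed service could implement it. Durations come from `System.nanoTime()`, which is monotonic and therefore suited to measuring elapsed time.

```java
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

/** Hypothetical callback interface mirroring AIMonitoringService's recording methods. */
interface Metrics {
    void requestStarted();
    void requestSucceeded(long durationMs);
    void requestFailed(String errorType);
}

/** Runs a model call and reports start, duration, and outcome to Metrics. */
final class MonitoredCall {

    static <T> T run(Metrics metrics, Supplier<T> call) {
        metrics.requestStarted();
        long start = System.nanoTime();
        try {
            T result = call.get();
            metrics.requestSucceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
            return result;
        } catch (RuntimeException e) {
            // The exception class name makes a low-cardinality tag value
            metrics.requestFailed(e.getClass().getSimpleName());
            throw e;
        }
    }
}
```

Tagging errors by exception class rather than by message keeps the number of distinct time series bounded.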
5.3 Cost Control Strategies
package com.example.ai.cost;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Value;
import java.time.LocalDate;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.DoubleAdder;
/**
* AI cost control service
*/
@Service
public class CostControlService {
@Value("${app.ai.cost.token-limit-per-day}")
private int tokenLimitPerDay;
@Value("${app.ai.cost.token-cost-per-thousand}")
private double tokenCostPerThousand;
private final AtomicInteger tokensUsedToday = new AtomicInteger(0);
// Accumulate dollars in a DoubleAdder; rounding each request to whole cents would drop small costs entirely
private final DoubleAdder totalCost = new DoubleAdder();
// Calendar day (server time zone) the counter belongs to, instead of a rolling 24 h window from startup
private LocalDate currentDay = LocalDate.now();
/**
* Check whether the estimated tokens fit into today's remaining budget.
* Checking and recording are separate steps, so concurrent requests can slightly overshoot the limit.
*/
public boolean isWithinDailyLimit(int estimatedTokens) {
checkAndResetDailyCounter();
return tokensUsedToday.get() + estimatedTokens <= tokenLimitPerDay;
}
/**
* Record token usage for a completed request
*/
public void recordTokenUsage(int promptTokens, int completionTokens) {
checkAndResetDailyCounter();
int totalTokens = promptTokens + completionTokens;
tokensUsedToday.addAndGet(totalTokens);
// Cost in dollars, e.g. 1,500 tokens at $0.002 per 1K tokens = $0.003
totalCost.add(totalTokens / 1000.0 * tokenCostPerThousand);
}
/**
* Get the number of tokens used today
*/
public int getTokensUsedToday() {
checkAndResetDailyCounter();
return tokensUsedToday.get();
}
/**
* Get the total cost since startup (dollars)
*/
public double getTotalCost() {
return totalCost.sum();
}
/**
* Reset the daily counter when the calendar day changes; synchronized so only one thread performs the reset
*/
private synchronized void checkAndResetDailyCounter() {
LocalDate today = LocalDate.now();
if (!today.equals(currentDay)) {
tokensUsedToday.set(0);
currentDay = today;
}
}
}
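Daily budgets are easier to test when "today" can be controlled, and safer when checking and recording happen in one step. The sketch below restates the reset-on-new-calendar-day logic with an injectable date supplier. It uses a single synchronized `tryConsume` that checks the limit and reserves the tokens together, so concurrent callers cannot jointly overshoot the limit. `DailyTokenBudget` is hypothetical. Because tokens are reserved from an estimate up front, a real implementation would still need to reconcile that estimate with the usage the model actually reports.

```java
import java.time.LocalDate;
import java.util.function.Supplier;

/** Token budget that resets when the calendar day changes; the date source is injectable for tests. */
final class DailyTokenBudget {
    private final int limitPerDay;
    private final double costPerThousand;
    private final Supplier<LocalDate> today;
    private LocalDate day;
    private int usedToday;
    private double totalCost;

    DailyTokenBudget(int limitPerDay, double costPerThousand, Supplier<LocalDate> today) {
        this.limitPerDay = limitPerDay;
        this.costPerThousand = costPerThousand;
        this.today = today;
        this.day = today.get();
    }

    /** Checks the budget and reserves the tokens in one step; returns false if they do not fit. */
    synchronized boolean tryConsume(int tokens) {
        rollOver();
        if (usedToday + tokens > limitPerDay) {
            return false;
        }
        usedToday += tokens;
        totalCost += tokens / 1000.0 * costPerThousand;   // dollars
        return true;
    }

    synchronized int usedToday() {
        rollOver();
        return usedToday;
    }

    synchronized double totalCost() {
        return totalCost;
    }

    /** Starts a new day's count when the date changes; the total cost keeps accumulating. */
    private void rollOver() {
        LocalDate now = today.get();
        if (!now.equals(day)) {
            day = now;
            usedToday = 0;
        }
    }
}
```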
5.4 Model Evaluation and Quality Control
package com.example.ai.quality;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Response quality evaluator
* Uses a model as a judge to score another model's answer
*/
@Service
public class ResponseQualityEvaluator {
private final ObjectMapper objectMapper = new ObjectMapper();
@Autowired
private ChatClient evaluatorModel;
/**
* Evaluate the quality of an AI response
*/
public QualityScore evaluateResponse(String userQuery, String aiResponse) {
// Build the evaluation prompt with an explicit JSON shape so the result can be parsed
List<Message> messages = List.of(
new SystemMessage(
"You are an AI response evaluator. Your task is to evaluate the quality of an AI assistant's response " +
"to a user query. Rate the response on a scale of 1-10 for each of the following criteria: " +
"Relevance, Accuracy, Completeness, Clarity, and Helpfulness. " +
"Respond with only a JSON object that maps each criterion to an object with an integer 'score' " +
"and a string 'explanation', e.g. {\"Relevance\": {\"score\": 8, \"explanation\": \"...\"}}."
),
new UserMessage(
"User Query: " + userQuery + "\n\n" +
"AI Response: " + aiResponse + "\n\n" +
"Please evaluate this response."
)
);
// Call the evaluator model
String evaluationResult = evaluatorModel.call(new Prompt(messages)).getResult().getOutput().getContent();
// Parse the JSON object, ignoring any text the model adds around it
QualityScore score = new QualityScore();
int start = evaluationResult.indexOf('{');
int end = evaluationResult.lastIndexOf('}');
if (start < 0 || end <= start) {
return score; // No JSON found: return an empty score (average 0.0)
}
try {
JsonNode root = objectMapper.readTree(evaluationResult.substring(start, end + 1));
root.fields().forEachRemaining(entry -> score.addScore(
entry.getKey(),
entry.getValue().path("score").asInt(),
entry.getValue().path("explanation").asText()));
} catch (JsonProcessingException e) {
// Malformed JSON: leave the score empty
}
return score;
}
/**
* Quality score
*/
public static class QualityScore {
private final Map<String, Integer> scores = new HashMap<>();
private final Map<String, String> explanations = new HashMap<>();
public void addScore(String criterion, int score, String explanation) {
scores.put(criterion, score);
explanations.put(criterion, explanation);
}
public double getAverageScore() {
if (scores.isEmpty()) {
return 0.0;
}
int sum = scores.values().stream().mapToInt(Integer::intValue).sum();
return (double) sum / scores.size();
}
public Map<String, Integer> getScores() {
return scores;
}
public Map<String, String> getExplanations() {
return explanations;
}
}
}
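Reliable parsing depends on pinning down the evaluator's output format. The sketch below is a dependency-free alternative to JSON. It assumes the evaluator is instructed to reply with one `Criterion: score - explanation` line per criterion (an assumption, since the prompt above asks for JSON). It parses those lines with a regular expression and skips surrounding prose and out-of-range scores. `ScoreLineParser` is hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Parses lines like "Relevance: 8 - on topic" into criterion -> (score 1-10, explanation). */
final class ScoreLineParser {

    private static final Pattern LINE =
            Pattern.compile("^\\s*([A-Za-z]+)\\s*:\\s*(\\d{1,2})\\s*(?:-\\s*(.*))?$");

    record Entry(int score, String explanation) {}

    static Map<String, Entry> parse(String text) {
        Map<String, Entry> result = new LinkedHashMap<>();
        for (String line : text.split("\\R")) {
            Matcher m = LINE.matcher(line);
            if (!m.matches()) {
                continue;                               // skip prose the model adds around the scores
            }
            int score = Integer.parseInt(m.group(2));
            if (score < 1 || score > 10) {
                continue;                               // reject out-of-range ratings
            }
            String explanation = m.group(3) == null ? "" : m.group(3).trim();
            result.put(m.group(1), new Entry(score, explanation));
        }
        return result;
    }

    static double average(Map<String, Entry> scores) {
        return scores.values().stream().mapToInt(Entry::score).average().orElse(0.0);
    }
}
```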
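6. Summary and Outlook
6.1 The Future of Spring AI
The Spring AI 1.0 release is only a starting point. As AI technology moves quickly, Spring AI can be expected to evolve in the following directions:
- Broader model support: Spring AI 1.0 already integrates with providers beyond OpenAI, Azure OpenAI, and Anthropic, including Google Gemini, Mistral AI, and local models such as Llama via Ollama. The list of supported open-source and commercial models is likely to keep growing.
- Stronger multimodal capabilities: better handling of images, audio, and video, so that Spring AI can serve more complex scenarios.
- Local model deployment: simpler ways to run models locally, which reduces API costs and keeps data in-house for better privacy.
- Richer tool integration: more prebuilt tools and plugins that simplify AI development for specific scenarios.
- Stronger enterprise features: further improvements to security, observability, and governance to meet the strict requirements of enterprise applications.
6.2 Best Practices for Enterprise Applications
When adopting Spring AI in an enterprise environment, the following practices are recommended:
- Layered architecture: encapsulate AI capabilities in a dedicated service layer, decoupled from business logic, so they are easy to maintain and upgrade.
- Incremental adoption: start with small, non-critical use cases and gradually expand to core business processes.
- Multi-model strategy: choose a model for each task according to its complexity and importance, balancing quality and cost.
- Comprehensive monitoring: track request counts, latency, errors, and token usage so that performance problems and anomalies are caught early.
- Continuous evaluation: regularly assess response quality, and refine prompts and model settings based on the results.
- Security and compliance: strictly control what sensitive data is sent to models, and make sure AI applications comply with applicable regulations and company policies.
6.3 Closing Thoughts
Spring AI 1.0 gives Java developers a powerful and flexible framework that makes building enterprise AI applications simpler and more efficient. Its unified abstraction layer, broad feature set, and seamless integration with the Spring ecosystem address many pain points of hand-rolled AI integrations.
As AI technology advances and Spring AI continues to evolve, enterprise applications built on it are likely to play an increasingly important role and create more value for businesses.
Whether you are just starting to explore AI application development or already have substantial experience, Spring AI offers the tools and abstractions you need to build the next generation of intelligent applications. Now is a good time to start.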