How Can Java Developers Embrace Large Models? Spring AI 1.0 Has the Answer

1. Introduction: Why Spring AI

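As the AI wave sweeps the globe, large language models (LLMs) are reshaping software development with unprecedented depth and breadth. For Java developers, and especially for engineers who work deep in the Spring ecosystem, the pressing question is how to integrate AI capabilities into existing microservices and enterprise applications elegantly, efficiently, and reliably.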

“I just want to call OpenAI from my Spring Boot project. Why does it have to be so complicated?”

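Many engineers probably feel this way. Calling the model API directly with RestTemplate or HttpClient certainly works, but you soon run into a series of engineering challenges: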

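  • Model heterogeneity: OpenAI, Azure OpenAI, Google Gemini, Anthropic Claude, China's Tongyi Qianwen, and others each have their own API, authentication scheme, and request/response format, so switching models is expensive.
  • Fragmented functionality: a complete AI application is more than API calls. Prompt engineering, structured output, conversation memory, RAG (retrieval-augmented generation), and tool calling all need a lot of glue code to hold together.
  • Enterprise requirements: production environments must account for non-functional concerns such as connection pool management, asynchronous calls, retries on error, monitoring, logging, and configuration management.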

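Spring AI was created to address exactly these pain points. It does not set out to reinvent the wheel; it positions itself as the “Spring Boot” of AI applications. Through strong abstractions and auto-configuration it simplifies the complex work of AI integration, so that developers can add AI capabilities to an application as easily as they use spring-boot-starter-web.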

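Choosing Spring AI means more than picking a library; it means adopting a standard approach to engineering AI into applications. It lets us:

  1. Focus on business logic: developers are freed from tedious API adaptation and low-level plumbing.
  2. Stay inside the Spring ecosystem: it integrates with Spring Boot, Spring Cloud, Spring Data, and the rest of the ecosystem.
  3. Be production-ready: it provides the portability, extensibility, and maintainability that enterprise applications require.

This article starts with environment setup, works through Spring AI's core features step by step, and uses a complete enterprise-grade case study to show the design and how it holds up in practice.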

1.1 Core Advantages of Spring AI

graph TD
    A[Spring AI 1.0] --> B[Unified abstraction layer]
    A --> C[Production ready]
    A --> D[Spring ecosystem integration]
    A --> E[Multi-model support]

    B --> B1[ChatModel interface]
    B --> B2[EmbeddingModel interface]
    B --> B3[ImageModel interface]

    C --> C1[Metrics and monitoring]
    C --> C2[Security and authentication]
    C --> C3[Configuration management]

    D --> D1[Spring Boot Starter]
    D --> D2[Spring Cloud integration]
    D --> D3[Spring Security]

    E --> E1[OpenAI]
    E --> E2[Azure OpenAI]
    E --> E3[Anthropic]
    E --> E4[Local models]

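Imagine having no Spring Framework and hand-writing database connections, transaction management, and dependency injection every time. Spring AI aims to solve the equivalent problems in AI integration.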

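1.2 When to Use It

Scenarios where Spring AI fits well:

  • Enterprise Java applications that need to integrate AI capabilities
  • Applications that must support several AI models
  • Projects with high stability requirements in production
  • Teams that want fast prototyping and iteration

Scenarios where it fits less well:

  • Pure AI research projects (the Python ecosystem is richer)
  • Real-time applications with extreme performance requirements
  • Simple one-off AI calls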

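2. Quick Start: Environment Setup and Basic Configuration

2.1 Maven Dependency Configuration

Let's start with the most basic dependency configuration. Spring AI is modular, so you can pick support for specific models as needed.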

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
         http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    
    <groupId>com.example</groupId>
    <artifactId>spring-ai-demo</artifactId>
    <version>1.0.0</version>
    <packaging>jar</packaging>
    
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <!-- Spring AI 1.0 targets Spring Boot 3.4.x and later; 3.2.0 is too old -->
        <version>3.4.5</version>
        <relativePath/>
    </parent>
    
    <properties>
        <java.version>17</java.version>
        <spring-ai.version>1.0.0</spring-ai.version>
    </properties>
    
    <dependencies>
        <!-- Spring Boot base dependencies -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        
        <!-- Spring AI core abstractions: in 1.0 GA the former spring-ai-core module was
             split up, and its parts are pulled in transitively by the starters below -->

        <!-- OpenAI support (1.0 starter naming) -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-model-openai</artifactId>
            <version>${spring-ai.version}</version>
        </dependency>
        
        <!-- Vector store support (1.0 starter naming) -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-vector-store-chroma</artifactId>
            <version>${spring-ai.version}</version>
        </dependency>
        
        <!-- Document processing support -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pdf-document-reader</artifactId>
            <version>${spring-ai.version}</version>
        </dependency>
        
        <!-- Monitoring and metrics -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-actuator</artifactId>
        </dependency>
        
        <dependency>
            <groupId>io.micrometer</groupId>
            <artifactId>micrometer-registry-prometheus</artifactId>
        </dependency>
        
        <!-- Test dependencies -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
    
    <repositories>
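        <!-- Spring milestone repository; 1.0.0 GA artifacts are on Maven Central,
             so this is only needed for milestone builds -->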
        <repository>
            <id>spring-milestones</id>
            <name>Spring Milestones</name>
            <url>https://repo.spring.io/milestone</url>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
</project>

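2.2 The application.yml Configuration File

The configuration file is the "brain" of a Spring AI application; sensible configuration keeps the application stable across environments.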

spring:
  application:
    name: spring-ai-demo
  
  # AI model configuration
  ai:
    openai:
      api-key: ${OPENAI_API_KEY:your-api-key}
      base-url: ${OPENAI_BASE_URL:https://api.openai.com}
      chat:
        options:
          model: gpt-4
          temperature: 0.7
          max-tokens: 2048
          top-p: 1.0
          frequency-penalty: 0.0
          presence-penalty: 0.0
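        # Connection pool settings (custom keys; they only take effect if the
        # application binds them to its own HTTP client configuration)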
        http-client:
          connect-timeout: 30s
          read-timeout: 60s
          max-connections: 20
          max-connections-per-route: 10
      
      embedding:
        options:
          model: text-embedding-ada-002
          dimensions: 1536
    
    # Vector store configuration
    vectorstore:
      chroma:
        host: ${CHROMA_HOST:localhost}
        port: ${CHROMA_PORT:8000}
        collection-name: ${CHROMA_COLLECTION:spring-ai-docs}
        distance-function: COSINE
        
  # Redis configuration (used for session storage)
  data:
    redis:
      host: ${REDIS_HOST:localhost}
      port: ${REDIS_PORT:6379}
      password: ${REDIS_PASSWORD:}
      database: 0
      timeout: 2000ms
      lettuce:
        pool:
          max-active: 20
          max-idle: 10
          min-idle: 5
          max-wait: 2000ms

# Application settings
app:
  ai:
    # Conversation settings
    conversation:
      max-history: 10
      ttl: 3600  # 1 hour
      cleanup-interval: 300  # clean up every 5 minutes
    
    # RAG settings
    rag:
      chunk-size: 1000
      chunk-overlap: 200
      max-results: 5
      similarity-threshold: 0.7
    
    # Retry settings
    retry:
      max-attempts: 3
      backoff-delay: 1000
      max-delay: 10000
    
    # Rate limiting
    rate-limit:
      requests-per-minute: 60
      burst-capacity: 10

# Monitoring
management:
  endpoints:
    web:
      exposure:
        include: health,info,metrics,prometheus
  endpoint:
    health:
      show-details: always
  # Spring Boot 3.x moved the Prometheus export switch to this key
  prometheus:
    metrics:
      export:
        enabled: true
  metrics:
    tags:
      application: ${spring.application.name}
      environment: ${spring.profiles.active:default}

# Logging
logging:
  level:
    org.springframework.ai: DEBUG
    com.example: DEBUG
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level [%X{traceId},%X{spanId}] %logger{36} - %msg%n"
    file: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level [%X{traceId},%X{spanId}] %logger{36} - %msg%n"
  file:
    name: logs/spring-ai-demo.log
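  # Rolling-file settings live under logging.logback.rollingpolicy in current Spring Boot
  logback:
    rollingpolicy:
      max-file-size: 100MB
      max-history: 30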

---
# Development profile
spring:
  config:
    activate:
      on-profile: dev
  ai:
    openai:
      chat:
        options:
          temperature: 0.9  # more creative output is acceptable in development
logging:
  level:
    root: INFO
    org.springframework.ai: DEBUG

---
# Production profile
spring:
  config:
    activate:
      on-profile: prod
  ai:
    openai:
      chat:
        options:
          temperature: 0.3  # more conservative in production
          max-tokens: 1024  # keep cost under control
logging:
  level:
    root: WARN
    com.example: INFO
app:
  ai:
    rate-limit:
      requests-per-minute: 30  # stricter rate limit in production

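3. Core Capabilities in Depth

3.1 Wrapping Model Calls

3.1.1 A Unified AI Client Wrapper

Enterprise applications usually need to support several AI models behind a single calling interface. Spring AI's abstraction layer handles this well.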

package com.example.ai.service;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.retry.annotation.Retryable;
import org.springframework.retry.annotation.Backoff;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.time.Duration;

/**
 * Unified wrapper around AI model calls.
 * Adds multi-model support, caching, retries, and other enterprise features.
 */
@Service
public class AIModelService {
    
    // In Spring AI 1.0 the type exposing call(Prompt) and stream(Prompt) is ChatModel.
    // The bean names below are assumptions: define two ChatModel beans with these names.
    @Autowired
    @Qualifier("primaryChatModel")
    private ChatModel primaryChatClient;
    @Autowired
    @Qualifier("fallbackChatModel")
    private ChatModel fallbackChatClient;
    @Autowired
    private AIMetricsService metricsService;
    
    /**
     * Synchronous chat call.
     * @param prompt user input
     * @param context context parameters
     * @return the AI response
     */
    @Retryable(
        value = {Exception.class},
        maxAttempts = 3,
        backoff = @Backoff(delay = 1000, multiplier = 2)
    )
    @Cacheable(value = "ai-responses", key = "#prompt + #context.hashCode()")
    public String chat(String prompt, Map<String, Object> context) {
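        long startTime = System.currentTimeMillis();
        
        try {
            // Build the prompt
            Prompt aiPrompt = buildPrompt(prompt, context);
            
            // Call the primary model
            ChatResponse response = primaryChatClient.call(aiPrompt);
            
            // Record metrics
            metricsService.recordSuccess("primary", 
                System.currentTimeMillis() - startTime);
            
            // Spring AI 1.0 exposes the reply text via getText()
            return response.getResult().getOutput().getText();
            
        } catch (Exception e) {
            // Degrade to the fallback model. Because the exception is handled here,
            // @Retryable only fires when the fallback fails as well.
            return fallbackChat(prompt, context, e);
        }
    }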
    
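    /**
     * Asynchronous chat call.
     * Note: chat() is invoked on this instance directly, so the Spring proxy
     * (and with it @Retryable and @Cacheable) is bypassed on this path.
     * @param prompt user input
     * @param context context parameters
     * @return the asynchronous response
     */
    public CompletableFuture<String> chatAsync(String prompt, Map<String, Object> context) {
        return CompletableFuture.supplyAsync(() -> chat(prompt, context));
    }
    
    /**
     * Streaming chat call.
     * @param prompt user input
     * @param context context parameters
     * @param callback streaming callback
     */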
    public void chatStream(String prompt, Map<String, Object> context, 
                          StreamCallback callback) {
        Prompt aiPrompt = buildPrompt(prompt, context);
        
        primaryChatClient.stream(aiPrompt)
            .subscribe(
                chatResponse -> {
                    String content = chatResponse.getResult().getOutput().getText();
                    callback.onNext(content);
                },
                error -> {
                    callback.onError(error);
                    metricsService.recordError("primary", error);
                },
                () -> callback.onComplete()
            );
    }
    
    /**
     * Call the fallback model.
     */
    private String fallbackChat(String prompt, Map<String, Object> context, Exception originalError) {
        long startTime = System.currentTimeMillis();
        
        try {
            metricsService.recordFallback("primary", originalError);
            
            Prompt aiPrompt = buildPrompt(prompt, context);
            ChatResponse response = fallbackChatClient.call(aiPrompt);
            
            metricsService.recordSuccess("fallback", 
                System.currentTimeMillis() - startTime);
            
            return response.getResult().getOutput().getText();
            
        } catch (Exception fallbackError) {
            metricsService.recordError("fallback", fallbackError);
            throw new AIServiceException("Both primary and fallback models failed", 
                originalError, fallbackError);
        }
    }
    
    /**
     * Build the prompt.
     */
    private Prompt buildPrompt(String userInput, Map<String, Object> context) {
        // The prompt template can be chosen dynamically from the context
        if (context.containsKey("template")) {
            PromptTemplate template = new PromptTemplate((String) context.get("template"));
            return template.create(context);
        } else {
            return new Prompt(List.of(new UserMessage(userInput)));
        }
    }
    
    /**
     * Streaming callback interface.
     */
    public interface StreamCallback {
        void onNext(String content);
        void onError(Throwable error);
        void onComplete();
    }
}
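
The degrade-to-fallback flow in AIModelService can be looked at in isolation from Spring. The following is a minimal sketch using only the JDK; FallbackInvoker and the Function-based stand-ins for the two models are hypothetical names introduced for illustration, not Spring AI types.

```java
import java.util.function.Function;

/** Hypothetical helper: try a primary model call, then a fallback one. */
final class FallbackInvoker {

    private final Function<String, String> primary;
    private final Function<String, String> fallback;

    FallbackInvoker(Function<String, String> primary, Function<String, String> fallback) {
        this.primary = primary;
        this.fallback = fallback;
    }

    /** Returns the primary answer, or the fallback answer if the primary call throws. */
    String call(String prompt) {
        try {
            return primary.apply(prompt);
        } catch (RuntimeException primaryError) {
            try {
                return fallback.apply(prompt);
            } catch (RuntimeException fallbackError) {
                // Keep the primary failure as the cause and attach the fallback failure.
                IllegalStateException both = new IllegalStateException(
                    "Both primary and fallback models failed", primaryError);
                both.addSuppressed(fallbackError);
                throw both;
            }
        }
    }

    public static void main(String[] args) {
        FallbackInvoker invoker = new FallbackInvoker(
            prompt -> { throw new IllegalStateException("primary down"); },
            prompt -> "fallback: " + prompt);
        System.out.println(invoker.call("hello"));
    }
}
```

Keeping both exceptions (one as the cause, one as suppressed) preserves the diagnostic information that the two-argument AIServiceException in the listing above is meant to carry.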

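3.1.2 Implementing a Multi-Model Adapter Pattern

The adapter pattern lets us smooth over the interface differences between AI service providers.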

graph TD
    subgraph Your Application
        A[MyService] --> B{ChatClient};
    end

    subgraph Spring AI Core
        B -- uses --> C[ChatModel];
    end
    
    subgraph Spring AI Adapters
        C -- can be implemented by --> D[OpenAiChatModel];
        C -- can be implemented by --> E[VertexAiGeminiChatModel];
        C -- can be implemented by --> F[OllamaChatModel];
    end

    A --> |"Tell me a joke"| B;
    B --> |Generates a Prompt object| C;
    C --> |Calls external API| D;
    D --> |Sends JSON to OpenAI API| G[OpenAI API];
    G --> |Returns JSON response| D;
    D --> |Parses to ChatResponse| C;
    C --> |Returns ChatResponse| B;
    B --> |Extracts content| A;

package com.example.ai.adapter;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

/**
 * Model adapter interface.
 * Unifies how different AI service providers are called.
 */
public interface ModelAdapter {
    
    /**
     * Returns the model name.
     */
    String getModelName();
    
    /**
     * Returns the provider name.
     */
    String getProvider();
    
    /**
     * Calls the model.
     */
    ChatResponse call(Prompt prompt);
    
    /**
     * Checks whether the model is available.
     */
    boolean isAvailable();
    
    /**
     * Returns the model configuration.
     */
    ModelConfig getConfig();
}

package com.example.ai.adapter;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.stereotype.Component;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;

import lombok.extern.slf4j.Slf4j;

/**
 * Adapter for OpenAI models.
 * In Spring AI 1.0 the OpenAI implementation class is OpenAiChatModel.
 */
@Slf4j
@Component
@Qualifier("openaiAdapter")
public class OpenAIAdapter implements ModelAdapter {
    
    @Autowired
    private OpenAiChatModel chatModel;
    @Autowired
    private ModelConfig config;

    @Override
    public String getModelName() {
        return config.getModelName();
    }
    
    @Override
    public String getProvider() {
        return config.getProvider();
    }
    
    @Override
    public ChatResponse call(Prompt prompt) {
        try {
            log.debug("Calling OpenAI model with prompt: {}", prompt);
            return chatModel.call(prompt);
        } catch (Exception e) {
            log.error("OpenAI model call failed", e);
            throw new ModelCallException("OpenAI call failed", e);
        }
    }
    
    @Override
    public boolean isAvailable() {
        try {
            // Simple health check; note that it issues a real (billable) model call
            Prompt healthCheck = new Prompt("Hello");
            ChatResponse response = chatModel.call(healthCheck);
            return response != null && response.getResult() != null;
        } catch (Exception e) {
            log.warn("OpenAI health check failed", e);
            return false;
        }
    }
    
    @Override
    public ModelConfig getConfig() {
        return config;
    }
}
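
Once several adapters exist, something has to choose between them. A minimal JDK-only sketch of one possible policy (first adapter that reports itself available) follows; SimpleAdapter, StubAdapter, and AdapterSelector are hypothetical stand-ins for the ModelAdapter interface above, trimmed so the example runs without Spring.

```java
import java.util.List;
import java.util.Optional;

/** Hypothetical selector: picks the first adapter that reports itself available. */
final class AdapterSelector {

    private AdapterSelector() {
    }

    static Optional<SimpleAdapter> firstAvailable(List<? extends SimpleAdapter> candidates) {
        for (SimpleAdapter candidate : candidates) {
            if (candidate.isAvailable()) {
                return Optional.of(candidate);
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        List<StubAdapter> adapters = List.of(
            new StubAdapter("openai", false),
            new StubAdapter("ollama", true));
        System.out.println(firstAvailable(adapters)
            .map(SimpleAdapter::provider)
            .orElse("none"));
    }
}

/** Trimmed-down stand-in for the ModelAdapter interface. */
interface SimpleAdapter {
    String provider();
    boolean isAvailable();
}

/** Test double whose availability is fixed at construction time. */
record StubAdapter(String provider, boolean available) implements SimpleAdapter {
    @Override
    public boolean isAvailable() {
        return available;
    }
}
```

The list order doubles as the priority order, so the preferred provider goes first. In a real service the availability result would normally be cached, since the isAvailable() implementation above performs a model call.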

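3.1.3 Asynchronous Calls and Connection Pool Management

Under high concurrency, asynchronous calls and connection pool management are critical.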

package com.example.ai.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;

@Configuration
@EnableAsync
public class AsyncConfig {
    
    @Bean(name = "aiTaskExecutor")
    public Executor aiTaskExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        
        // Core pool size
        executor.setCorePoolSize(10);
        
        // Maximum pool size
        executor.setMaxPoolSize(50);
        
        // Queue capacity
        executor.setQueueCapacity(200);
        
        // Thread name prefix
        executor.setThreadNamePrefix("AI-Task-");
        
        // Rejection policy: run the task on the caller's thread
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        
        // Wait for all tasks to finish before shutting down
        executor.setWaitForTasksToCompleteOnShutdown(true);
        
        // Maximum time to wait on shutdown
        executor.setAwaitTerminationSeconds(60);
        
        executor.initialize();
        return executor;
    }
}

package com.example.ai.service;

import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.concurrent.CompletableFuture;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.extern.slf4j.Slf4j;

@Slf4j
@Service
public class AsyncAIService {
    
    @Autowired
    private AIModelService aiModelService;
    
    /**
     * Asynchronous batch processing.
     * @param prompts list of prompts
     * @param context context parameters
     * @return the list of results, asynchronously
     */
    @Async("aiTaskExecutor")
    public CompletableFuture<List<String>> batchProcess(
            List<String> prompts, Map<String, Object> context) {
        
        log.info("Starting batch processing for {} prompts", prompts.size());
        
        List<CompletableFuture<String>> futures = prompts.stream()
            .map(prompt -> aiModelService.chatAsync(prompt, context))
            .collect(Collectors.toList());
        
        // Wait for all tasks to complete
        CompletableFuture<Void> allOf = CompletableFuture.allOf(
            futures.toArray(new CompletableFuture[0]));
        
        return allOf.thenApply(v -> 
            futures.stream()
                .map(CompletableFuture::join)
                .collect(Collectors.toList())
        ).whenComplete((result, throwable) -> {
            if (throwable != null) {
                log.error("Batch processing failed", throwable);
            } else {
                log.info("Batch processing completed successfully, {} results", result.size());
            }
        });
    }
    
    /**
     * Asynchronous processing of a single prompt.
     * @param prompt the prompt
     * @param context context parameters
     * @return the result, asynchronously
     */
    @Async("aiTaskExecutor")
    public CompletableFuture<String> process(String prompt, Map<String, Object> context) {
        log.info("Starting async processing for prompt: {}", prompt);
        
        return aiModelService.chatAsync(prompt, context)
            .whenComplete((result, throwable) -> {
                if (throwable != null) {
                    log.error("Async processing failed", throwable);
                } else {
                    log.info("Async processing completed successfully");
                }
            });
    }
}

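3.1.4 Error Handling and Retries

In production, AI services can fail temporarily because of network jitter, provider rate limiting, and similar causes. Combined with Spring Retry, Spring AI provides a robust retry mechanism.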

package com.example.ai.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.retry.annotation.EnableRetry;
import org.springframework.retry.backoff.ExponentialBackOffPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;

import java.util.HashMap;
import java.util.Map;

@Configuration
@EnableRetry
public class RetryConfig {
    
    @Bean
    public RetryTemplate retryTemplate() {
        RetryTemplate retryTemplate = new RetryTemplate();
        
        // Exponential backoff policy
        ExponentialBackOffPolicy backOffPolicy = new ExponentialBackOffPolicy();
        backOffPolicy.setInitialInterval(1000); // initial interval: 1 second
        backOffPolicy.setMultiplier(2.0);      // double on each retry
        backOffPolicy.setMaxInterval(10000);   // cap at 10 seconds
        retryTemplate.setBackOffPolicy(backOffPolicy);
        
        // Retry policy: retry on network-related exceptions
        Map<Class<? extends Throwable>, Boolean> retryableExceptions = new HashMap<>();
        retryableExceptions.put(java.net.ConnectException.class, true);
        retryableExceptions.put(java.net.SocketTimeoutException.class, true);
        retryableExceptions.put(java.io.IOException.class, true);
        retryableExceptions.put(org.springframework.web.client.ResourceAccessException.class, true);
        
        SimpleRetryPolicy retryPolicy = new SimpleRetryPolicy(3, retryableExceptions, true);
        retryTemplate.setRetryPolicy(retryPolicy);
        
        return retryTemplate;
    }
}
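
To make the backoff numbers concrete, the waits implied by an initial interval, a multiplier, and a cap can be computed directly. This is a minimal JDK-only sketch of the arithmetic, not Spring Retry's implementation; BackoffSchedule and delaysMillis are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper that lists the waits of a capped exponential backoff. */
final class BackoffSchedule {

    private BackoffSchedule() {
    }

    /**
     * @param initialMillis wait before the first retry
     * @param multiplier    growth factor applied after each retry
     * @param maxMillis     upper bound for any single wait
     * @param retries       number of waits to compute
     */
    static List<Long> delaysMillis(long initialMillis, double multiplier,
                                   long maxMillis, int retries) {
        List<Long> delays = new ArrayList<>();
        double next = initialMillis;
        for (int i = 0; i < retries; i++) {
            delays.add((long) Math.min(next, maxMillis));
            next *= multiplier;
        }
        return delays;
    }

    public static void main(String[] args) {
        // Three attempts in total means two waits between them.
        System.out.println(delaysMillis(1000, 2.0, 10000, 2));
        // A longer schedule shows where the cap takes over.
        System.out.println(delaysMillis(1000, 2.0, 10000, 5));
    }
}
```

With the values from RetryConfig (1000 ms initial, multiplier 2.0, 10000 ms cap) and three attempts, the two waits are 1000 ms and 2000 ms, so the cap only matters if the attempt count is raised.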

3.1.5 Resilience Patterns

Combining a circuit breaker, rate limiting, and timeout control makes AI service calls more robust.

package com.example.ai.service;

import io.github.resilience4j.circuitbreaker.annotation.CircuitBreaker;
import io.github.resilience4j.ratelimiter.annotation.RateLimiter;
import io.github.resilience4j.timelimiter.annotation.TimeLimiter;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;

import java.util.Map;
import java.util.concurrent.CompletableFuture;

@Service
public class ResilientAIService {
    
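    @Autowired
    private AIModelService aiModelService;
    // Note: with a single AIModelService bean this injects the same instance as above;
    // a genuinely separate fallback needs its own bean and a @Qualifier.
    @Autowired
    private AIModelService fallbackService;
    
    /**
     * Resilient AI call.
     * - Circuit breaker: fail fast after repeated failures
     * - Rate limiter: control the call frequency
     * - Time limiter: bound the call duration
     */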
    @CircuitBreaker(name = "aiService", fallbackMethod = "fallbackChat")
    @RateLimiter(name = "aiService")
    @TimeLimiter(name = "aiService")
    public CompletableFuture<String> resilientChat(String prompt, Map<String, Object> context) {
        return CompletableFuture.supplyAsync(() -> {
            return aiModelService.chat(prompt, context);
        });
    }
    
    /**
     * Fallback method.
     */
    public CompletableFuture<String> fallbackChat(
            String prompt, Map<String, Object> context, Throwable throwable) {
        return CompletableFuture.supplyAsync(() -> {
            // Log the original error
            LoggingService.logError("Primary AI service failed", throwable);
            
            try {
                // Try the fallback service
                return fallbackService.chat(prompt, context);
            } catch (Exception e) {
                // If the fallback also fails, return a canned degraded response
                LoggingService.logError("Fallback AI service also failed", e);
                return "I'm sorry, but I'm currently experiencing technical difficulties. "
                     + "Please try again later or contact support if the issue persists.";
            }
        });
    }
}

3.2 Prompt Engineering in Practice

3.2.1 Prompt Template Management

Spring AI provides prompt template support that helps us organize and reuse prompts.

package com.example.ai.prompt;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Prompt template builder.
 * Used to assemble more complex prompts.
 */
public class PromptTemplateBuilder {
    
    private String systemMessage;
    private final List<String> examples = new ArrayList<>();
    private final Map<String, Object> variables = new HashMap<>();
    
    /**
     * Sets the system message.
     */
    public PromptTemplateBuilder withSystemMessage(String systemMessage) {
        this.systemMessage = systemMessage;
        return this;
    }
    
    /**
     * Adds an example.
     */
    public PromptTemplateBuilder addExample(String example) {
        this.examples.add(example);
        return this;
    }
    
    /**
     * Adds a variable.
     */
    public PromptTemplateBuilder addVariable(String key, Object value) {
        this.variables.put(key, value);
        return this;
    }
    
    /**
     * Adds several variables at once.
     */
    public PromptTemplateBuilder addVariables(Map<String, Object> variables) {
        this.variables.putAll(variables);
        return this;
    }
    
    /**
     * Builds the prompt.
     */
    public Prompt build(String templateContent) {
        List<Message> messages = new ArrayList<>();
        
        // Add the system message
        if (systemMessage != null && !systemMessage.isEmpty()) {
            messages.add(new SystemMessage(systemMessage));
        }
        
        // Add the examples
        for (String example : examples) {
            messages.add(new UserMessage(example));
        }
        
        // Create the template and fill in the variables
        PromptTemplate template = new PromptTemplate(templateContent);
        messages.add(new UserMessage(template.render(variables)));
        
        return new Prompt(messages);
    }
}
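
What "fill in the variables" amounts to can be shown without the library. The following is a minimal JDK-only sketch of {name}-style placeholder substitution; it illustrates the idea and is not how Spring AI's PromptTemplate is implemented. PlaceholderRenderer is a hypothetical name, and rejecting missing variables is a design choice made for this sketch.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical renderer that substitutes {name} placeholders from a map. */
final class PlaceholderRenderer {

    // Matches {identifier}; anything else in braces is left untouched.
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    private PlaceholderRenderer() {
    }

    static String render(String template, Map<String, ?> variables) {
        Matcher matcher = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (matcher.find()) {
            String name = matcher.group(1);
            if (!variables.containsKey(name)) {
                // Fail loudly instead of sending a half-filled prompt to the model.
                throw new IllegalArgumentException("Missing template variable: " + name);
            }
            String value = String.valueOf(variables.get(name));
            // quoteReplacement keeps '$' and '\' in the value from being interpreted.
            matcher.appendReplacement(out, Matcher.quoteReplacement(value));
        }
        matcher.appendTail(out);
        return out.toString();
    }

    public static void main(String[] args) {
        System.out.println(render(
            "Review this {language} code: {code}",
            Map.of("language", "Java", "code", "int x;")));
    }
}
```

Failing on a missing variable is usually preferable to silently leaving the placeholder in place, because a prompt with a literal {code} in it still produces a plausible-looking model answer.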

3.2.2 Prompt Template Manager

package com.example.ai.prompt;

import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.stereotype.Component;
import org.springframework.beans.factory.annotation.Autowired;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Prompt template manager.
 * Loads and caches prompt templates.
 */
@Component
public class PromptTemplateManager {
    
    @Autowired
    private ResourceLoader resourceLoader;
    private final Map<String, PromptTemplate> templateCache = new ConcurrentHashMap<>();  
    
    /**
     * Loads a template from a classpath resource.
     */
    public PromptTemplate loadTemplate(String templatePath) {
        return templateCache.computeIfAbsent(templatePath, path -> {
            try {
                Resource resource = resourceLoader.getResource("classpath:" + path);
                String content = new String(
                    resource.getInputStream().readAllBytes(), 
                    StandardCharsets.UTF_8
                );
                return new PromptTemplate(content);
            } catch (IOException e) {
                throw new RuntimeException("Failed to load template: " + path, e);
            }
        });
    }
    
    /**
     * Returns the code review prompt template.
     */
    public PromptTemplate getCodeReviewTemplate() {
        return loadTemplate("templates/code-review.txt");
    }
    
    /**
     * Returns the document QA prompt template.
     */
    public PromptTemplate getDocumentQATemplate() {
        return loadTemplate("templates/document-qa.txt");
    }
    
    /**
     * Returns the customer service prompt template.
     */
    public PromptTemplate getCustomerServiceTemplate() {
        return loadTemplate("templates/customer-service.txt");
    }
}

3.2.3 Dynamic Prompt Generation

package com.example.ai.prompt;

import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;

/**
 * Dynamic prompt generator.
 * Produces contextual prompts based on the user's role and the scenario.
 */
@Component
public class DynamicPromptGenerator {
    
    /**
     * Generates a context-aware prompt.
     */
    public String generateContextualPrompt(String basePrompt, String userRole, String scenario) {
        StringBuilder promptBuilder = new StringBuilder();
        
        // Add time context
        promptBuilder.append("Current time: ")
                     .append(LocalDateTime.now().format(
                         DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")))
                     .append("\n\n");
        
        // Add role context
        if (userRole != null && !userRole.isEmpty()) {
            promptBuilder.append("You are interacting with a user who has the role: ")
                         .append(userRole)
                         .append("\n");
        }
        
        // Add scenario context
        if (scenario != null && !scenario.isEmpty()) {
            promptBuilder.append("The current scenario is: ")
                         .append(scenario)
                         .append("\n\n");
        }
        
        // Add the base prompt
        promptBuilder.append(basePrompt);
        
        // Add style instructions
        promptBuilder.append("\n\nPlease provide a response that is professional, "
                           + "concise, and directly addresses the query above.");
        
        return promptBuilder.toString();
    }
    
    /**
     * Generates a prompt with an output format requirement.
     */
    public String generateFormattedPrompt(String basePrompt, String outputFormat) {
        StringBuilder promptBuilder = new StringBuilder(basePrompt);
        
        promptBuilder.append("\n\nPlease format your response as ")
                     .append(outputFormat);
        
        return promptBuilder.toString();
    }
}

3.3 Structured Output Handling

3.3.1 Output Schema Validation

package com.example.ai.output;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.networknt.schema.JsonSchema;
import com.networknt.schema.JsonSchemaFactory;
import com.networknt.schema.SpecVersion;
import com.networknt.schema.ValidationMessage;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Output schema validator.
 * Checks that the AI output conforms to the expected JSON schema.
 */
@Component
public class OutputSchemaValidator {
    
    private final ObjectMapper objectMapper = new ObjectMapper();
    
    /**
     * Validates JSON output against a schema.
     * @param jsonOutput JSON output generated by the AI
     * @param schemaJson JSON schema definition
     * @return the validation result
     */
    public ValidationResult validate(String jsonOutput, String schemaJson) {
        try {
            // Parse the JSON output and the schema
            JsonNode outputNode = objectMapper.readTree(jsonOutput);
            JsonNode schemaNode = objectMapper.readTree(schemaJson);
            
            // Create the schema validator
            JsonSchemaFactory factory = JsonSchemaFactory.getInstance(SpecVersion.VersionFlag.V7);
            JsonSchema schema = factory.getSchema(schemaNode);
            
            // Run the validation
            Set<ValidationMessage> validationMessages = schema.validate(outputNode);
            
            if (validationMessages.isEmpty()) {
                return new ValidationResult(true, null);
            } else {
                String errorDetails = validationMessages.stream()
                    .map(ValidationMessage::getMessage)
                    .collect(Collectors.joining("; "));
                return new ValidationResult(false, errorDetails);
            }
            
        } catch (IOException e) {
            return new ValidationResult(false, "Invalid JSON format: " + e.getMessage());
        }
    }
    
    /**
     * Validation result.
     */
    public static class ValidationResult {
        private final boolean valid;
        private final String errorDetails;
        
        public ValidationResult(boolean valid, String errorDetails) {
            this.valid = valid;
            this.errorDetails = errorDetails;
        }
        
        public boolean isValid() {
            return valid;
        }
        
        public String getErrorDetails() {
            return errorDetails;
        }
    }
}

3.3.2 Output Conversion and Mapping

package com.example.ai.output;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.stereotype.Component;

import java.io.IOException;

/**
 * Output converter.
 * Converts AI output into Java objects.
 */
@Component
public class OutputConverter {
    
    private final ObjectMapper objectMapper = new ObjectMapper();
    
    /**
     * Converts an AI response into an object of the given type.
     * @param response the AI response
     * @param targetClass the target type
     * @return the converted object
     */
    public <T> T convertToObject(ChatResponse response, Class<T> targetClass) {
        // Spring AI 1.0 exposes the reply text via getText()
        String content = response.getResult().getOutput().getText();
        return convertJsonToObject(content, targetClass);
    }
    
    /**
     * Converts a JSON string into an object of the given type.
     * @param jsonContent the JSON string
     * @param targetClass the target type
     * @return the converted object
     */
    public <T> T convertJsonToObject(String jsonContent, Class<T> targetClass) {
        try {
            return objectMapper.readValue(jsonContent, targetClass);
        } catch (IOException e) {
            throw new OutputConversionException(
                "Failed to convert output to " + targetClass.getSimpleName(), e);
        }
    }
}

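3.4 Conversation Memory Management

The core of conversation memory is to store the dialogue history (user inputs and AI replies) and attach it to the prompt on the next request.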

sequenceDiagram
    participant User
    participant App as Spring Boot App
    participant ChatMemory as Redis/DB
    participant LLM as Large Language Model

    User->>+App: "Hello, my name is Zhang San"
    App->>LLM: Prompt: "Hello, my name is Zhang San"
    LLM->>App: "Hello Zhang San! How can I help you?"
    App->>ChatMemory: Store [User: "Hello, my name is Zhang San", AI: "Hello Zhang San!..."] for session_id_123
    App->>-User: "Hello Zhang San! How can I help you?"

    User->>+App: "What is my name?"
    App->>+ChatMemory: Retrieve history for session_id_123
    ChatMemory->>-App: [User: "Hello, my name is Zhang San", AI: "Hello Zhang San!..."]
    App->>LLM: Prompt: "[History]\nUser: Hello, my name is Zhang San\nAI: Hello Zhang San!...\n\n[Current]\nUser: What is my name?"
    LLM->>App: "Your name is Zhang San."
    App->>ChatMemory: Store [User: "What is my name?", AI: "Your name is Zhang San."] for session_id_123
    App->>-User: "Your name is Zhang San."
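
The prompt-assembly step in the diagram above can be sketched in plain Java. This is a minimal illustration, not Spring AI code: `HistoryPromptSketch`, `Turn`, and `buildPrompt` are hypothetical names, and the `[History]` / `[Current]` layout simply mirrors the diagram.

```java
import java.util.List;

// Hypothetical helper for illustration only; these are not Spring AI classes.
final class HistoryPromptSketch {

    /** One stored turn: who spoke ("User" or "AI") and what was said. */
    record Turn(String role, String text) {}

    /** Builds the prompt text: stored history first, then the current user input. */
    static String buildPrompt(List<Turn> history, String currentInput) {
        StringBuilder sb = new StringBuilder();
        if (!history.isEmpty()) {
            sb.append("[History]\n");
            for (Turn t : history) {
                sb.append(t.role()).append(": ").append(t.text()).append('\n');
            }
            sb.append('\n');
        }
        sb.append("[Current]\nUser: ").append(currentInput);
        return sb.toString();
    }
}
```

On the first request the history is empty, so only the current input is sent; from the second request on, earlier turns are prepended so the model can resolve references such as "my name".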

3.4.1 Session State Management

package com.example.ai.session;

import org.springframework.ai.chat.messages.Message;
import org.springframework.data.annotation.Id;
import org.springframework.data.redis.core.RedisHash;

import java.io.Serializable;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Session state.
 * Stores the conversation history between a user and the AI.
 */
@RedisHash("ai_sessions")
public class SessionState implements Serializable {
    
    @Id
    private String sessionId;
    
    private String userId;
    
    private List<MessageEntry> messageHistory = new ArrayList<>();
    
    private Instant createdAt;
    
    private Instant lastUpdatedAt;
    
    private int tokenCount;
    
    private Map<String, Object> metadata = new HashMap<>();
    
    // Constructors, getters and setters omitted
    
    /**
     * Appends a message to the history.
     */
    public void addMessage(Message message, int tokens) {
        MessageEntry entry = new MessageEntry(
            message.getMessageType().toString(),
            message.getText(),
            Instant.now(),
            tokens
        );
        
        messageHistory.add(entry);
        tokenCount += tokens;
        lastUpdatedAt = Instant.now();
    }
    
    /**
     * Returns the most recent N messages.
     */
    public List<MessageEntry> getRecentMessages(int count) {
        int startIndex = Math.max(0, messageHistory.size() - count);
        // Return a copy so callers do not hold a view onto the internal list
        return new ArrayList<>(messageHistory.subList(startIndex, messageHistory.size()));
    }
    
    /**
     * Drops the oldest messages once the history exceeds the given size.
     */
    public void pruneHistory(int maxMessages) {
        if (messageHistory.size() > maxMessages) {
            int removeCount = messageHistory.size() - maxMessages;
            messageHistory = new ArrayList<>(messageHistory.subList(removeCount, messageHistory.size()));
            // Keep the token total consistent with the messages that remain
            tokenCount = messageHistory.stream().mapToInt(entry -> entry.tokens).sum();
        }
    }
    
    /**
     * A single message entry.
     */
    public static class MessageEntry implements Serializable {
        private String type;
        private String content;
        private Instant timestamp;
        private int tokens;
        
        // Constructors, getters and setters omitted
    }
}

3.4.2 Session Management Service

package com.example.ai.session;

import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;

/**
 * Session management service.
 * Manages conversation sessions between users and the AI.
 */
@Service
public class SessionManager {
    
    // SessionRepository and TokenCounter are application-defined components
    @Autowired
    private SessionRepository sessionRepository;
    @Autowired
    private TokenCounter tokenCounter;
    
    /**
     * Creates a new session.
     */
    public SessionState createSession(String userId) {
        SessionState session = new SessionState();
        session.setSessionId(generateSessionId());
        session.setUserId(userId);
        session.setCreatedAt(Instant.now());
        session.setLastUpdatedAt(Instant.now());
        
        return sessionRepository.save(session);
    }
    
    /**
     * Looks up a session.
     */
    public Optional<SessionState> getSession(String sessionId) {
        return sessionRepository.findById(sessionId);
    }
    
    /**
     * Appends a message to a session.
     */
    public void addMessageToSession(String sessionId, Message message) {
        sessionRepository.findById(sessionId).ifPresent(session -> {
            int tokens = tokenCounter.countTokens(message.getText());
            session.addMessage(message, tokens);
            sessionRepository.save(session);
        });
    }
    
    /**
     * Returns the most recent messages of a session (empty if the session does not exist).
     */
    public List<Message> getSessionHistory(String sessionId, int maxMessages) {
        return sessionRepository.findById(sessionId)
            .map(session -> convertToMessages(session.getRecentMessages(maxMessages)))
            .orElse(List.of());
    }
    
    /**
     * Deletes sessions that have not been updated within the given number of hours.
     */
    public void cleanupExpiredSessions(int maxAgeHours) {
        // Duration avoids int overflow for large hour values
        Instant cutoffTime = Instant.now().minus(Duration.ofHours(maxAgeHours));
        List<SessionState> expiredSessions = sessionRepository.findByLastUpdatedAtBefore(cutoffTime);
        sessionRepository.deleteAll(expiredSessions);
    }
    
    // Helper methods (generateSessionId, convertToMessages) omitted
}
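
Two of the pieces the listing leaves out can be sketched with the JDK alone. The names below (`SessionHelpersSketch`, `Entry`, `recentWithinBudget`) are hypothetical stand-ins rather than the author's code: `generateSessionId` assumes a random UUID is an acceptable session id, and `recentWithinBudget` shows one way to use the per-message token counts to trim history by token budget instead of by message count.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

final class SessionHelpersSketch {

    /** Hypothetical stand-in for SessionState.MessageEntry. */
    record Entry(String type, String content, int tokens) {}

    /** Generates an opaque session id (assumption: a random UUID is sufficient). */
    static String generateSessionId() {
        return UUID.randomUUID().toString();
    }

    /**
     * Returns the longest run of most-recent messages whose token total fits
     * the budget, in chronological order.
     */
    static List<Entry> recentWithinBudget(List<Entry> history, int maxTokens) {
        int used = 0;
        int start = history.size();
        // Walk backwards from the newest message until the budget would be exceeded
        for (int i = history.size() - 1; i >= 0; i--) {
            int tokens = history.get(i).tokens();
            if (used + tokens > maxTokens) {
                break;
            }
            used += tokens;
            start = i;
        }
        return new ArrayList<>(history.subList(start, history.size()));
    }
}
```

Trimming by tokens rather than by message count keeps the prompt inside the model's context window even when individual messages vary widely in length.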

3.5 RAG (Retrieval-Augmented Generation)

3.5.1 Vector Store Service

package com.example.ai.rag;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.List;

/**
 * Vector store service.
 * Manages document embedding and retrieval.
 */
@Service
public class VectorStoreService {
    
    @Autowired
    private VectorStore vectorStore;
    @Autowired
    private EmbeddingModel embeddingModel;
    
    /**
     * Adds documents to the vector store.
     */
    public void addDocuments(List<Document> documents) {
        vectorStore.add(documents);
    }
    
    /**
     * Retrieves the k documents most relevant to the query.
     */
    public List<Document> search(String query, int k) {
        // Spring AI 1.0 takes the top-k limit through a SearchRequest
        return vectorStore.similaritySearch(
            SearchRequest.builder().query(query).topK(k).build());
    }
    
    /**
     * Deletes documents by id.
     */
    public void removeDocuments(List<String> ids) {
        vectorStore.delete(ids);
    }
}

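3.5.2 Document Processing and Chunking

package com.example.ai.rag;

import org.springframework.ai.document.Document;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.core.io.FileSystemResource;
import org.springframework.stereotype.Component;

import java.io.File;
import java.util.List;

/**
 * Document processor.
 * Reads documents, splits them into chunks and attaches metadata.
 */
@Component
public class DocumentProcessor {
    
    /**
     * Processes a PDF document.
     */
    public List<Document> processPdfDocument(File pdfFile, String documentId) {
        // Create the PDF reader (it expects a Spring Resource rather than a File)
        PagePdfDocumentReader reader = new PagePdfDocumentReader(new FileSystemResource(pdfFile));
        
        // Read the document, one Document per page
        List<Document> documents = reader.get();
        
        // Attach metadata to each page
        documents.forEach(doc -> {
            doc.getMetadata().put("source", pdfFile.getName());
            doc.getMetadata().put("documentId", documentId);
        });
        
        // Configure the splitter through its constructor. The minimum chunk size is
        // counted in characters, not tokens, so 350 is the library default rather
        // than a direct equivalent of a 64-token minimum.
        TokenTextSplitter splitter = new TokenTextSplitter(
            512,    // target chunk size in tokens
            350,    // minimum chunk size in characters
            5,      // minimum chunk length to embed
            10000,  // maximum number of chunks
            true    // keep separators
        );
        
        // Split the pages into chunks
        return splitter.apply(documents);
    }
}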
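
To make the chunking step concrete, here is a deliberately simplified sketch that uses whitespace-separated words as a stand-in for tokens. It is not how `TokenTextSplitter` works internally; `ChunkingSketch` and its `maxWords` / `minWords` parameters are assumptions chosen to illustrate the maximum and minimum chunk sizes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ChunkingSketch {

    /**
     * Splits text into chunks of about maxWords words. A trailing chunk shorter
     * than minWords is merged into the previous chunk instead of standing alone,
     * so the last chunk may be slightly longer than maxWords.
     */
    static List<String> chunk(String text, int maxWords, int minWords) {
        List<String> chunks = new ArrayList<>();
        String trimmed = text.trim();
        if (trimmed.isEmpty()) {
            return chunks;
        }
        String[] words = trimmed.split("\\s+");
        for (int start = 0; start < words.length; start += maxWords) {
            int end = Math.min(start + maxWords, words.length);
            String piece = String.join(" ", Arrays.copyOfRange(words, start, end));
            boolean shortTail = end == words.length
                && (end - start) < minWords
                && !chunks.isEmpty();
            if (shortTail) {
                int last = chunks.size() - 1;
                chunks.set(last, chunks.get(last) + " " + piece);
            } else {
                chunks.add(piece);
            }
        }
        return chunks;
    }
}
```

Very small trailing chunks tend to embed poorly because they carry little context, which is the motivation for a minimum size.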

3.5.3 RAG Query Augmentation

package com.example.ai.rag;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.ArrayList;
import java.util.List;

/**
 * RAG service.
 * Implements retrieval-augmented generation.
 */
@Service
public class RAGService {
    
    @Autowired
    private VectorStoreService vectorStoreService;
    @Autowired
    private ChatModel chatModel;
    
    /**
     * Runs a RAG query.
     */
    public String query(String userQuery) {
        // Retrieve the most relevant documents
        List<Document> relevantDocs = vectorStoreService.search(userQuery, 3);
        
        // Build the context
        String context = buildContext(relevantDocs);
        
        // Build the prompt
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage(
            "You are a helpful assistant. Use the following context to answer the user's question. " +
            "If the answer cannot be found in the context, say 'I don't have enough information to answer that.'\n\n" +
            "Context:\n" + context
        ));
        messages.add(new UserMessage(userQuery));
        
        // Call the AI model
        Prompt prompt = new Prompt(messages);
        return chatModel.call(prompt).getResult().getOutput().getText();
    }
    
    /**
     * Concatenates the retrieved documents into a context block.
     */
    private String buildContext(List<Document> documents) {
        StringBuilder contextBuilder = new StringBuilder();
        
        for (Document doc : documents) {
            contextBuilder.append("--- Document: ")
                         .append(doc.getMetadata().get("source"))
                         .append(" ---\n")
                         .append(doc.getText())
                         .append("\n\n");
        }
        
        return contextBuilder.toString();
    }
}
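
The retrieval step delegates to the vector store, so the ranking itself is not visible in the listing. As a rough mental model, a top-k similarity search can be sketched as a cosine-similarity ranking over stored embeddings. `SimilaritySearchSketch` is a hypothetical, brute-force illustration; real vector stores typically use approximate indexes rather than a full scan.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

final class SimilaritySearchSketch {

    /** Cosine similarity of two equal-length vectors; 0 if either is all zeros. */
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            return 0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Returns the ids of the k stored vectors most similar to the query, best first. */
    static List<String> topK(Map<String, double[]> store, double[] query, int k) {
        List<Map.Entry<String, double[]>> entries = new ArrayList<>(store.entrySet());
        entries.sort(Comparator.comparingDouble(
            (Map.Entry<String, double[]> e) -> cosine(e.getValue(), query)).reversed());
        List<String> ids = new ArrayList<>();
        for (int i = 0; i < Math.min(k, entries.size()); i++) {
            ids.add(entries.get(i).getKey());
        }
        return ids;
    }
}
```

In the RAG flow, the query text is first embedded into a vector by the embedding model, and the documents behind the returned ids become the context passed to the chat model.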

3.5.4 Incremental Learning and Knowledge Base Updates

package com.example.ai.rag;

import org.springframework.ai.document.Document;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.io.File;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Knowledge base update service.
 * Periodically scans for new documents and adds them to the knowledge base.
 */
@Service
public class KnowledgeBaseUpdater {
    
    @Autowired
    private VectorStoreService vectorStoreService;
    @Autowired
    private DocumentProcessor documentProcessor;
    private final Set<String> processedFiles = ConcurrentHashMap.newKeySet();
    
    /**
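     * Periodically scans the document directory and updates the knowledge base
     * (requires scheduling to be enabled, e.g. with @EnableScheduling).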
     */
    @Scheduled(fixedDelayString = "${app.ai.rag.update-interval:3600000}")
    public void updateKnowledgeBase() {
        Path docsDirectory = Paths.get("docs");
        File directory = docsDirectory.toFile();
        
        if (directory.exists() && directory.isDirectory()) {
            scanDirectory(directory);
        }
    }
    
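    /**
     * Scans a directory tree for PDF documents that have not been processed yet.
     */
    private void scanDirectory(File directory) {
        File[] files = directory.listFiles();
        if (files == null) return;
        
        for (File file : files) {
            if (file.isDirectory()) {
                scanDirectory(file);
            } else if (isPdfFile(file) && !isProcessed(file) && processFile(file)) {
                // Mark only after a successful import so failed files are retried on the next run
                markAsProcessed(file);
            }
        }
    }
    
    /**
     * Processes a file and adds its chunks to the vector store.
     * @return true if the file was imported successfully
     */
    private boolean processFile(File file) {
        try {
            String documentId = file.getAbsolutePath();
            List<Document> chunks = documentProcessor.processPdfDocument(file, documentId);
            vectorStoreService.addDocuments(chunks);
            return true;
        } catch (Exception e) {
            // Log the error here and continue with the remaining files
            return false;
        }
    }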
    
    private boolean isPdfFile(File file) {
        return file.getName().toLowerCase().endsWith(".pdf");
    }
    
    private boolean isProcessed(File file) {
        return processedFiles.contains(file.getAbsolutePath() + "_" + file.lastModified());
    }
    
    private void markAsProcessed(File file) {
        processedFiles.add(file.getAbsolutePath() + "_" + file.lastModified());
    }
}

3.6 Tool Calling Integration

3.6.1 Tool Definition and Registration

package com.example.ai.tool;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Tool annotation.
 * Marks a method as a tool that the AI may call.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface AITool {
    
    /**
     * Tool name.
     */
    String name();
    
    /**
     * Tool description.
     */
    String description();
    
    /**
     * Tool category.
     */
    ToolCategory category() default ToolCategory.GENERAL;
    
    /**
     * Tool category enumeration.
     */
    enum ToolCategory {
        GENERAL,
        DATA_ANALYSIS,
        EXTERNAL_API,
        SYSTEM_OPERATION
    }
}

package com.example.ai.tool;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Tool registry.
 * Manages and executes the tools that the AI may call.
 * ToolDefinition and ToolExecutionException are application-defined classes (not shown).
 */
@Service
public class ToolRegistry {
    
    private final Map<String, ToolDefinition> tools = new HashMap<>();
    private final ObjectMapper objectMapper = new ObjectMapper();
    @Autowired
    private OpenAiChatModel chatModel;
    
    /**
     * Registers a tool.
     */
    public void registerTool(Object instance, Method method, AITool annotation) {
        ToolDefinition tool = new ToolDefinition(
            annotation.name(),
            annotation.description(),
            annotation.category(),
            instance,
            method
        );
        
        tools.put(annotation.name(), tool);
    }
    
    /**
     * Executes a tool call.
     */
    public Object executeTool(String toolName, Map<String, Object> parameters) {
        ToolDefinition tool = tools.get(toolName);
        if (tool == null) {
            throw new IllegalArgumentException("Unknown tool: " + toolName);
        }
        
        try {
            return tool.execute(parameters);
        } catch (Exception e) {
            throw new ToolExecutionException("Failed to execute tool: " + toolName, e);
        }
    }
    
    /**
     * Returns the tool list in a function-calling style structure.
     */
    public List<Map<String, Object>> getToolDefinitions() {
        List<Map<String, Object>> definitions = new ArrayList<>();
        
        for (ToolDefinition tool : tools.values()) {
            Map<String, Object> definition = new HashMap<>();
            definition.put("type", "function");
            
            Map<String, Object> function = new HashMap<>();
            function.put("name", tool.getName());
            function.put("description", tool.getDescription());
            
            // A parameter schema definition could be added here
            
            definition.put("function", function);
            definitions.add(definition);
        }
        
        return definitions;
    }
    
    /**
     * Produces an AI response that may be backed by a tool call.
     */
    public String getToolEnhancedResponse(String userQuery) {
        // Build a system prompt that tells the AI which tools are available
        StringBuilder systemPrompt = new StringBuilder();
        systemPrompt.append("You have access to the following tools:\n\n");
        
        for (ToolDefinition tool : tools.values()) {
            systemPrompt.append("- ").append(tool.getName())
                       .append(": ").append(tool.getDescription())
                       .append("\n");
        }
        
        systemPrompt.append("\nWhen you need to use a tool, respond with JSON in this format:\n")
                   .append("{\"tool\": \"tool_name\", \"parameters\": {\"param1\": \"value1\"}}\n\n")
                   .append("If you don't need to use a tool, just respond normally.");
        
        // Build the prompt
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage(systemPrompt.toString()));
        messages.add(new UserMessage(userQuery));
        
        // Call the AI
        Prompt prompt = new Prompt(messages);
        String response = chatModel.call(prompt).getResult().getOutput().getText();
        
        // Check whether the response is a tool call request
        String trimmed = response.trim();
        if (trimmed.startsWith("{") && trimmed.contains("\"tool\"")) {
            try {
                JsonNode node = objectMapper.readTree(trimmed);
                String toolName = node.path("tool").asText();
                Map<String, Object> parameters = objectMapper.convertValue(
                    node.path("parameters"), new TypeReference<Map<String, Object>>() {});
                if (parameters == null) {
                    parameters = Map.of();
                }
                // Simplified: return the raw tool result. A fuller implementation
                // would pass the result back to the model for a final answer.
                return String.valueOf(executeTool(toolName, parameters));
            } catch (JsonProcessingException | IllegalArgumentException e) {
                // Not a well-formed call to a known tool; fall back to the raw response
            }
        }
        
        return response;
    }
}
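
The listing calls `tool.execute(parameters)` but does not show how a parameter map is bound to a Java method. One possible approach, sketched here with the JDK reflection API only, binds arguments by name using a parameter annotation. Everything in this block (`ToolInvokerSketch`, `@ToolParam`, `coerce`, the sample `Calculator`) is hypothetical and is not the author's `ToolDefinition`; the annotation is used so that the sketch does not depend on compiling with `-parameters`.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Map;

final class ToolInvokerSketch {

    /** Hypothetical annotation that names a tool parameter. */
    @Target(ElementType.PARAMETER)
    @Retention(RetentionPolicy.RUNTIME)
    @interface ToolParam {
        String value();
    }

    /** Invokes the method, binding each argument by the name given in @ToolParam. */
    static Object invoke(Object instance, Method method, Map<String, Object> parameters) {
        Parameter[] declared = method.getParameters();
        Object[] args = new Object[declared.length];
        for (int i = 0; i < declared.length; i++) {
            ToolParam name = declared[i].getAnnotation(ToolParam.class);
            if (name == null) {
                throw new IllegalArgumentException("Parameter " + i + " lacks @ToolParam");
            }
            args[i] = coerce(parameters.get(name.value()), declared[i].getType());
        }
        try {
            return method.invoke(instance, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Tool invocation failed: " + method.getName(), e);
        }
    }

    /** Converts JSON-style values (String, Number) to the declared parameter type. */
    private static Object coerce(Object value, Class<?> type) {
        if (value == null) {
            if (type.isPrimitive()) {
                throw new IllegalArgumentException("Missing value for primitive parameter");
            }
            return null;
        }
        if (type == double.class || type == Double.class) return ((Number) value).doubleValue();
        if (type == int.class || type == Integer.class) return ((Number) value).intValue();
        if (type == String.class) return value.toString();
        return value;
    }

    /** Sample tool, modelled on the calculator idea used in this article. */
    static final class Calculator {
        public double add(@ToolParam("a") double a, @ToolParam("b") double b) {
            return a + b;
        }
    }
}
```

The coercion step matters because values parsed from model-generated JSON arrive as generic `Integer`, `Double`, or `String` objects and rarely match the Java parameter types exactly.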

3.6.2 Tool Implementation Example

package com.example.ai.tool.impl;

import com.example.ai.tool.AITool;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;

/**
 * Utility tools.
 * Provides assorted helper functions.
 */
@Component
public class UtilityTools {
    
    /**
     * Returns the current time.
     */
    @AITool(
        name = "getCurrentTime",
        description = "Get the current date and time in various formats"
    )
    public Map<String, String> getCurrentTime(String format) {
        LocalDateTime now = LocalDateTime.now();
        Map<String, String> result = new HashMap<>();
        
        if (format == null || format.isEmpty()) {
            format = "yyyy-MM-dd HH:mm:ss";
        }
        
        result.put("formatted", now.format(DateTimeFormatter.ofPattern(format)));
        result.put("iso", now.toString());
        
        return result;
    }
    
    /**
     * Basic calculator.
     */
    @AITool(
        name = "calculator",
        description = "Perform basic arithmetic operations"
    )
    public double calculate(String operation, double a, double b) {
        switch (operation.toLowerCase()) {
            case "add":
                return a + b;
            case "subtract":
                return a - b;
            case "multiply":
                return a * b;
            case "divide":
                if (b == 0) throw new IllegalArgumentException("Cannot divide by zero");
                return a / b;
            default:
                throw new IllegalArgumentException("Unknown operation: " + operation);
        }
    }
}

3.7 MCP Multi-Model Chain Processing

3.7.1 Model Chain Definition

package com.example.ai.mcp;

import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Model processing chain.
 * Defines a pipeline in which each node's output becomes the next node's input.
 */
public class ModelChain {
    
    private final List<ChainNode> nodes = new ArrayList<>();
    
    /**
     * Adds a processing node.
     */
    public ModelChain addNode(String name, ChatModel model, Function<String, String> promptTransformer) {
        nodes.add(new ChainNode(name, model, promptTransformer));
        return this;
    }
    
    /**
     * Executes the chain.
     */
    public ChainResult execute(String initialInput) {
        String currentInput = initialInput;
        List<NodeResult> results = new ArrayList<>();
        
        for (ChainNode node : nodes) {
            // Transform the current input into this node's prompt
            String transformedPrompt = node.promptTransformer().apply(currentInput);
            
            // Call the model
            Prompt prompt = new Prompt(List.of(new UserMessage(transformedPrompt)));
            String output = node.model().call(prompt).getResult().getOutput().getText();
            
            // Record the result
            results.add(new NodeResult(node.name(), transformedPrompt, output));
            
            // The output feeds the next node
            currentInput = output;
        }
        
        return new ChainResult(results);
    }
    
    /**
     * A node in the chain.
     */
    private record ChainNode(String name, ChatModel model, Function<String, String> promptTransformer) {}
    
    /**
     * The result of a single node.
     */
    public record NodeResult(String nodeName, String input, String output) {}
    
    /**
     * The result of a chain execution.
     */
    public static class ChainResult {
        private final List<NodeResult> nodeResults;
        
        public ChainResult(List<NodeResult> nodeResults) {
            this.nodeResults = nodeResults;
        }
        
        public String getFinalOutput() {
            if (nodeResults.isEmpty()) {
                return "";
            }
            return nodeResults.get(nodeResults.size() - 1).output();
        }
        
        public List<NodeResult> getAllResults() {
            return nodeResults;
        }
    }
}

3.7.2 Multi-Model Collaboration Example

package com.example.ai.mcp;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;

/**
 * Content generation service.
 * Uses a multi-model chain to generate high-quality content.
 */
@Service
public class ContentGenerationService {
    
    // The qualifier values are application-defined bean names
    @Autowired
    @Qualifier("gpt4ChatClient")
    private ChatModel ideationModel;
    
    @Autowired
    @Qualifier("gpt35ChatClient")
    private ChatModel draftingModel;
    
    @Autowired
    @Qualifier("anthropicChatClient")
    private ChatModel editingModel;
    
    /**
     * Generates an article.
     */
    public String generateArticle(String topic, String targetAudience, int wordCount) {
        ModelChain chain = new ModelChain()
            // Ideation stage: use the more capable model to produce an outline
            .addNode("ideation", ideationModel, input -> 
                "Generate a detailed outline for an article about '" + topic + "' " +
                "targeting " + targetAudience + ". " +
                "Include main sections and key points for each section."
            )
            // Drafting stage: use the cheaper model to write the first draft
            .addNode("drafting", draftingModel, outline -> 
                "Using the following outline, write a " + wordCount + " word article " +
                "about '" + topic + "' for " + targetAudience + ". " +
                "Make it engaging and informative.\n\nOutline:\n" + outline
            )
            // Editing stage: use a different model to polish the draft
            .addNode("editing", editingModel, draft -> 
                "Edit and improve the following article draft. " +
                "Fix any grammar or style issues, improve flow and clarity, " +
                "and ensure it's engaging for " + targetAudience + ".\n\n" + draft
            );
        
        return chain.execute(topic).getFinalOutput();
    }
}

4. Real-World Case Studies

4.1 Intelligent Customer Service System

package com.example.ai.cases.customerservice;

import com.example.ai.session.SessionManager;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.ArrayList;
import java.util.List;

/**
 * Intelligent customer support service.
 */
@Service
public class CustomerSupportService {
    
    // Assumed cap on the history sent to the model; tune it to your context window
    private static final int MAX_HISTORY_MESSAGES = 20;
    
    @Autowired
    private ChatModel chatModel;
    @Autowired
    private SessionManager sessionManager;
    // KnowledgeBaseService is an application-defined retrieval service (not shown)
    @Autowired
    private KnowledgeBaseService knowledgeBaseService;
    
    /**
     * Handles a customer query.
     */
    public String handleCustomerQuery(String sessionId, String userQuery) {
        List<Message> messages = new ArrayList<>();
        
        // The system prompt is not stored in the session, so prepend it on every request
        messages.add(new SystemMessage(
            "You are a helpful customer support assistant for our product. " +
            "Be polite, concise, and helpful. If you don't know the answer, " +
            "suggest escalating to a human agent."
        ));
        
        // Load the conversation history
        messages.addAll(sessionManager.getSessionHistory(sessionId, MAX_HISTORY_MESSAGES));
        
        // Retrieve relevant knowledge base content
        String relevantInfo = knowledgeBaseService.retrieveRelevantInformation(userQuery);
        
        // Add the knowledge base context, if any
        if (relevantInfo != null && !relevantInfo.isEmpty()) {
            messages.add(new SystemMessage(
                "Here is some relevant information that might help with the response:\n" + 
                relevantInfo
            ));
        }
        
        // Add the user query
        messages.add(new UserMessage(userQuery));
        
        // Call the AI model
        Prompt prompt = new Prompt(messages);
        String response = chatModel.call(prompt).getResult().getOutput().getText();
        
        // Update the session history with both sides of this turn
        sessionManager.addMessageToSession(sessionId, new UserMessage(userQuery));
        sessionManager.addMessageToSession(sessionId, new AssistantMessage(response));
        
        return response;
    }
}

4.2 Intelligent Document Analysis

package com.example.ai.cases.docanalysis;

import com.example.ai.rag.DocumentProcessor;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Document analysis service.
 */
@Service
public class DocumentAnalysisService {
    
    private final ObjectMapper objectMapper = new ObjectMapper();
    
    @Autowired
    private ChatModel chatModel;
    @Autowired
    private DocumentProcessor documentProcessor;
    
    /**
     * Analyzes a document and generates a summary.
     */
    public String generateDocumentSummary(String documentPath) {
        String documentContent = loadDocumentContent(documentPath);
        
        // Build the prompt
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage(
            "You are a document analysis assistant. Your task is to analyze the provided document " +
            "and generate a comprehensive summary that captures the main points, key findings, " +
            "and important details. The summary should be well-structured and easy to understand."
        ));
        messages.add(new UserMessage(
            "Please analyze the following document and provide a summary:\n\n" + documentContent
        ));
        
        // Call the AI model
        Prompt prompt = new Prompt(messages);
        return chatModel.call(prompt).getResult().getOutput().getText();
    }
    
    /**
     * Extracts the key insights from a document.
     */
    public List<KeyInsight> extractKeyInsights(String documentPath) {
        String documentContent = loadDocumentContent(documentPath);
        
        // Build the prompt
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage(
            "Extract key insights from the document as a JSON array. " +
            "Each insight should have a 'category', 'title', and 'description'. " +
            "Focus on the most important information. Return only the JSON array."
        ));
        messages.add(new UserMessage(
            "Please extract key insights from the following document:\n\n" + documentContent
        ));
        
        // Call the AI model
        Prompt prompt = new Prompt(messages);
        String response = chatModel.call(prompt).getResult().getOutput().getText();
        
        // Parse the JSON response. Simplified: production code should also handle
        // responses wrapped in extra text or code fences.
        try {
            return objectMapper.readValue(response, new TypeReference<List<KeyInsight>>() {});
        } catch (JsonProcessingException e) {
            throw new IllegalStateException("Model did not return a valid JSON array of insights", e);
        }
    }
    
    /**
     * Reads the PDF at the given path and joins its chunks into one string.
     */
    private String loadDocumentContent(String documentPath) {
        List<Document> chunks = documentProcessor.processPdfDocument(new File(documentPath), documentPath);
        return chunks.stream().map(Document::getText).collect(Collectors.joining("\n\n"));
    }
    
    /**
     * A key insight extracted from a document.
     */
    public record KeyInsight(String category, String title, String description) {}
}

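5. Performance Optimization and Best Practices

5.1 Performance Optimization Strategies

In production, tuning the performance of AI services is essential. The key strategies are:

  1. Caching: for identical or similar queries, a cache can significantly reduce API calls and response time.
package com.example.ai.optimization;

import org.springframework.cache.annotation.Cacheable;
import org.springframework.stereotype.Service;

@Service
public class CachedAIService {
    
    // Requires caching to be enabled, e.g. @EnableCaching on a configuration class
    @Cacheable(value = "ai-responses", key = "#prompt + #temperature")
    public String getCachedResponse(String prompt, double temperature) {
        // The actual AI call goes here
        return "AI response";
    }
}
  2. Request batching: merge multiple requests into a single batch to reduce network round trips.
  3. Model selection: choose a model that matches the task's complexity, and use lightweight models for simple tasks.
  4. Prompt optimization: trim prompts to reduce the token count, which lowers both cost and latency.
  5. Connection pooling: reuse HTTP connections through a pool to avoid connection setup overhead.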
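
For the caching strategy, a small in-memory variant can be sketched with the JDK alone. This is an illustration under stated assumptions, not the article's `CachedAIService`: `PromptCacheSketch` is a hypothetical class, the model call is passed in as a plain `Function`, and the normalization (trim, collapse whitespace, lower-case) is one simple way to let near-identical prompts share a cache entry. Whether case-folding is safe depends on the use case.

```java
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

final class PromptCacheSketch {

    private final Map<String, String> cache;
    private int misses = 0;

    PromptCacheSketch(int capacity) {
        // accessOrder = true makes the LinkedHashMap evict the least recently used entry
        this.cache = new LinkedHashMap<>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String, String> eldest) {
                return size() > capacity;
            }
        };
    }

    /** Maps trivially different prompts to the same cache key. */
    static String normalize(String prompt) {
        return prompt.trim().replaceAll("\\s+", " ").toLowerCase(Locale.ROOT);
    }

    /** Returns the cached answer, or calls the model and caches the result. */
    synchronized String get(String prompt, Function<String, String> model) {
        String key = normalize(prompt);
        String cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        misses++;
        String answer = model.apply(prompt);
        cache.put(key, answer);
        return answer;
    }

    /** Number of calls that had to go to the model. */
    synchronized int misses() {
        return misses;
    }
}
```

Caching is only appropriate when repeated prompts should produce the same answer, for example with a low temperature; for creative generation a cache hit would suppress the variation the caller asked for.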

5.2 Monitoring and Observability

package com.example.ai.monitoring;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import jakarta.annotation.PostConstruct;

/**
 * AI monitoring service.
 */
@Service
public class AIMonitoringService {
    
    @Autowired
    private MeterRegistry meterRegistry;
    private Counter totalRequestsCounter;
    private Counter successRequestsCounter;
    private Counter failedRequestsCounter;
    private Timer responseTimeTimer;
    
    public AIMonitoringService() {
        // The counters and the timer are initialized in the @PostConstruct method below
    }
    
    @PostConstruct
    public void init() {
        this.totalRequestsCounter = Counter.builder("ai.requests.total")
            .description("Total number of AI requests")
            .register(meterRegistry);
            
        this.successRequestsCounter = Counter.builder("ai.requests.success")
            .description("Number of successful AI requests")
            .register(meterRegistry);
            
        this.failedRequestsCounter = Counter.builder("ai.requests.failed")
            .description("Number of failed AI requests")
            .register(meterRegistry);
            
        this.responseTimeTimer = Timer.builder("ai.response.time")
            .description("AI response time")
            .register(meterRegistry);
    }
    
    /**
     * Records the start of a request.
     */
    public void recordRequestStart() {
        totalRequestsCounter.increment();
    }
    
    /**
     * Records a successful request.
     */
    public void recordRequestSuccess(long durationMs) {
        successRequestsCounter.increment();
        responseTimeTimer.record(java.time.Duration.ofMillis(durationMs));
    }
    
    /**
     * Records a failed request.
     */
    public void recordRequestFailure(String errorType) {
        failedRequestsCounter.increment();
        Counter.builder("ai.requests.errors")
            .tag("error.type", errorType)
            .description("AI request errors by type")
            .register(meterRegistry)
            .increment();
    }
}

5.3 成本控制策略

package com.example.ai.cost;

import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Value;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * AI成本控制服务
 */
@Service
public class CostControlService {
    
    @Value("${app.ai.cost.token-limit-per-day}")
    private int tokenLimitPerDay;
    
    @Value("${app.ai.cost.token-cost-per-thousand}")
    private double tokenCostPerThousand;
    
    private final AtomicInteger tokensUsedToday = new AtomicInteger(0);
    private final AtomicLong totalCost = new AtomicLong(0);
    private long dayStartTimestamp = System.currentTimeMillis();
    
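    /**
     * Checks whether a request would stay within the daily token limit.
     */
    public boolean isWithinDailyLimit(int estimatedTokens) {
        // Reset the counter first if the current 24-hour window has elapsed
        checkAndResetDailyCounter();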
        
        return tokensUsedToday.get() + estimatedTokens <= tokenLimitPerDay;
    }
    
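    /**
     * Records the token usage of one request.
     */
    public void recordTokenUsage(int promptTokens, int completionTokens) {
        // Reset the counter first if the current 24-hour window has elapsed
        checkAndResetDailyCounter();
        
        // Update the token count
        int totalTokens = promptTokens + completionTokens;
        tokensUsedToday.addAndGet(totalTokens);
        
        // Compute the cost in US dollars and accumulate it in micro-dollars;
        // rounding each request to whole cents would turn most small requests into 0
        double cost = (totalTokens / 1000.0) * tokenCostPerThousand;
        totalCost.addAndGet(Math.round(cost * 1_000_000));
    }
    
    /**
     * Returns the number of tokens used today.
     */
    public int getTokensUsedToday() {
        checkAndResetDailyCounter();
        return tokensUsedToday.get();
    }
    
    /**
     * Returns the total cost in US dollars.
     */
    public double getTotalCost() {
        return totalCost.get() / 1_000_000.0; // convert micro-dollars back to dollars
    }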
    
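    /**
     * Resets the daily counter once the current 24-hour window has elapsed.
     * The window is a rolling 24 hours from the last reset, not a calendar day.
     * Synchronized so that concurrent callers see a consistent dayStartTimestamp.
     */
    private synchronized void checkAndResetDailyCounter() {
        long currentTime = System.currentTimeMillis();
        long oneDayInMillis = 24L * 60 * 60 * 1000;
        
        if (currentTime - dayStartTimestamp > oneDayInMillis) {
            tokensUsedToday.set(0);
            dayStartTimestamp = currentTime;
        }
    }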
}
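
The service above checks the limit and records usage in two separate steps, so two concurrent requests can both pass the check and together exceed the budget. Its window is also a rolling 24 hours rather than a calendar day. The following sketch shows one possible alternative, assuming a reset at midnight in the clock's time zone is what you want. `DailyTokenBudget` and `tryReserve` are hypothetical names introduced for illustration, not part of Spring AI.

```java
import java.time.Clock;
import java.time.LocalDate;

/** Daily token budget that resets on a calendar-day boundary and reserves tokens atomically. */
final class DailyTokenBudget {
    private final int tokenLimitPerDay;
    private final Clock clock;
    private LocalDate currentDay;
    private int tokensUsedToday;

    DailyTokenBudget(int tokenLimitPerDay, Clock clock) {
        this.tokenLimitPerDay = tokenLimitPerDay;
        this.clock = clock;
        this.currentDay = LocalDate.now(clock);
    }

    /** Checks the limit and books the tokens under one lock, so callers cannot overshoot. */
    synchronized boolean tryReserve(int estimatedTokens) {
        if (estimatedTokens < 0) {
            throw new IllegalArgumentException("estimatedTokens must not be negative");
        }
        rollOverIfNewDay();
        // Widen to long so a very large estimate cannot overflow the comparison
        if ((long) tokensUsedToday + estimatedTokens > tokenLimitPerDay) {
            return false;
        }
        tokensUsedToday += estimatedTokens;
        return true;
    }

    synchronized int tokensUsedToday() {
        rollOverIfNewDay();
        return tokensUsedToday;
    }

    /** Resets the counter when the date in the clock's zone has changed. */
    private void rollOverIfNewDay() {
        LocalDate today = LocalDate.now(clock);
        if (!today.equals(currentDay)) {
            currentDay = today;
            tokensUsedToday = 0;
        }
    }
}
```

Passing a `Clock` (for example `Clock.systemDefaultZone()` in production) keeps the rollover logic testable without waiting for midnight.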

5.4 Model Evaluation and Quality Control

package com.example.ai.quality;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Response quality evaluation service.
 */
@Service
public class ResponseQualityEvaluator {
    
    // In Spring AI 1.0, call(Prompt) belongs to ChatModel;
    // ChatClient is the fluent API built on top of it
    @Autowired
    private ChatModel evaluatorModel;
    
    /**
     * Evaluates the quality of an AI response.
     */
    public QualityScore evaluateResponse(String userQuery, String aiResponse) {
        // Build the evaluation prompt
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage(
            "You are an AI response evaluator. Your task is to evaluate the quality of an AI assistant's response " +
            "to a user query. Rate the response on a scale of 1-10 for each of the following criteria: " +
            "Relevance, Accuracy, Completeness, Clarity, and Helpfulness. " +
            "Provide your ratings in JSON format with a brief explanation for each score."
        ));
        
        messages.add(new UserMessage(
            "User Query: " + userQuery + "\n\n" +
            "AI Response: " + aiResponse + "\n\n" +
            "Please evaluate this response."
        ));
        
        // Call the evaluator model
        Prompt prompt = new Prompt(messages);
        String evaluationResult = evaluatorModel.call(prompt).getResult().getOutput().getText();
        
        // Parse the evaluation result (a real implementation needs more robust JSON handling)
        // Simplified here
        QualityScore score = new QualityScore();
        // Parsing logic...
        
        return score;
    }
    
    /**
     * Holds the per-criterion scores and explanations.
     */
    public static class QualityScore {
        private Map<String, Integer> scores = new HashMap<>();
        private Map<String, String> explanations = new HashMap<>();
        
        public void addScore(String criterion, int score, String explanation) {
            scores.put(criterion, score);
            explanations.put(criterion, explanation);
        }
        
        public double getAverageScore() {
            if (scores.isEmpty()) {
                return 0.0;
            }
            
            int sum = scores.values().stream().mapToInt(Integer::intValue).sum();
            return (double) sum / scores.size();
        }
        
        public Map<String, Integer> getScores() {
            return scores;
        }
        
        public Map<String, String> getExplanations() {
            return explanations;
        }
    }
}
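
The listing above leaves the parsing step open. The following sketch shows one way to fill it using only the JDK, under a loud assumption: the evaluator is taken to answer with entries shaped like `"Relevance": {"score": 8, "explanation": "..."}`. The system prompt does not pin down this shape, so it would have to be specified there as well. `EvaluationParser` and `ParsedEvaluation` are hypothetical names, and a production implementation would more likely use a real JSON library or a structured-output converter instead of a regular expression.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Parsed scores and explanations, keyed by criterion name. */
final class ParsedEvaluation {
    final Map<String, Integer> scores = new LinkedHashMap<>();
    final Map<String, String> explanations = new LinkedHashMap<>();
}

final class EvaluationParser {
    // Assumed entry shape: "Criterion": {"score": 8, "explanation": "text"}
    private static final Pattern ENTRY = Pattern.compile(
        "\"(\\w+)\"\\s*:\\s*\\{\\s*\"score\"\\s*:\\s*(\\d{1,2})\\s*,"
        + "\\s*\"explanation\"\\s*:\\s*\"((?:[^\"\\\\]|\\\\.)*)\"\\s*\\}");

    private EvaluationParser() {
    }

    static ParsedEvaluation parse(String evaluationResult) {
        ParsedEvaluation parsed = new ParsedEvaluation();
        // find() scans the whole text, so prose or code fences around the JSON are tolerated
        Matcher matcher = ENTRY.matcher(evaluationResult);
        while (matcher.find()) {
            int score = Integer.parseInt(matcher.group(2));
            // Skip ratings outside the 1-10 scale requested in the system prompt
            if (score >= 1 && score <= 10) {
                parsed.scores.put(matcher.group(1), score);
                // Escape sequences in the explanation are kept as-is, not unescaped
                parsed.explanations.put(matcher.group(1), matcher.group(3));
            }
        }
        return parsed;
    }
}
```

Each parsed entry could then be passed to `QualityScore.addScore(criterion, score, explanation)`.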

6. Summary and Outlook

6.1 The Future of Spring AI

The release of Spring AI 1.0 is only the beginning. As AI technology evolves rapidly, we can expect Spring AI to keep developing in the following directions:

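  1. Broader model support: beyond the OpenAI, Azure OpenAI, and Anthropic integrations, Spring AI already offers starters for other providers such as Ollama (for Llama-family models), Mistral AI, and Google Vertex AI Gemini, and future releases are likely to keep widening coverage of open-source and commercial models.
  2. Stronger multimodal capabilities: better handling of images, audio, and video, so that Spring AI can serve more complex application scenarios.
  3. Local model deployment: more convenient options for running models locally, which lowers API costs and strengthens data privacy.
  4. Richer tool integration: more prebuilt tools and plugins to simplify AI application development for specific scenarios.
  5. Stronger enterprise features: further improvements in security, observability, and governance to meet the strict requirements of enterprise applications.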

6.2 Best Practices for Enterprise Applications

When applying Spring AI in an enterprise environment, we recommend the following best practices:

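  1. Layered architecture: encapsulate AI capabilities in a dedicated service layer, decoupled from business logic, so they are easy to maintain and upgrade.
  2. Incremental adoption: start with small, non-critical business scenarios and expand step by step into more central business processes.
  3. Multi-model strategy: choose a model to match the complexity and importance of each task, balancing performance against cost.
  4. Thorough monitoring: establish comprehensive metrics so that performance problems and anomalies are detected and resolved quickly.
  5. Continuous evaluation and optimization: assess AI response quality regularly and refine prompts and model configuration based on the feedback.
  6. Security and compliance: strictly control the use of sensitive data and make sure the AI application complies with applicable regulations and company policy.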
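
The multi-model strategy in item 3 can be sketched as a small routing table that maps task complexity to a model name. Everything in this example is hypothetical: `TaskComplexity`, `ModelRouter`, and the model names are placeholders chosen for illustration, and a real application would map them to whatever models it has configured.

```java
import java.util.EnumMap;
import java.util.Map;
import java.util.Objects;

/** Hypothetical task tiers; a real application would define its own. */
enum TaskComplexity { SIMPLE, STANDARD, COMPLEX }

/** Maps each task tier to a model name, with a fallback for unmapped tiers. */
final class ModelRouter {
    private final Map<TaskComplexity, String> modelByComplexity =
        new EnumMap<>(TaskComplexity.class);
    private final String fallbackModel;

    ModelRouter(String fallbackModel) {
        this.fallbackModel = Objects.requireNonNull(fallbackModel);
    }

    /** Registers a model for a tier and returns this router for chaining. */
    ModelRouter route(TaskComplexity complexity, String modelName) {
        modelByComplexity.put(complexity, Objects.requireNonNull(modelName));
        return this;
    }

    /** Returns the model for the tier, or the fallback if none was registered. */
    String select(TaskComplexity complexity) {
        return modelByComplexity.getOrDefault(complexity, fallbackModel);
    }
}
```

Keeping the mapping in one place means cheap models can handle simple tasks while expensive ones are reserved for complex tasks, and the mapping can change without touching business code.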

6.3 Closing Thoughts

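Spring AI 1.0 gives Java developers a powerful and flexible framework that makes building enterprise-grade AI applications simpler and more efficient. Through a unified abstraction layer, a rich feature set, and seamless integration with the Spring ecosystem, it addresses many of the pain points of traditional AI integration approaches.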

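As AI technology advances and the Spring AI framework continues to evolve, we have good reason to believe that enterprise applications built on Spring AI will play an increasingly important role and create more value for the businesses that run them.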

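Whether you are just starting to explore AI application development or already have extensive experience, Spring AI offers the tools and abstractions you need to build the next generation of intelligent applications. Now is a great time to begin that journey!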