Springboot 实现RAG的检索、增强、生成

118 阅读7分钟

1. 配置类(Config层)

java复制代码
/**
 * Milvus向量数据库配置类
 * 功能:创建并配置Milvus客户端连接
 */
@Configuration
public class MilvusConfig {
    @Value("${milvus.host:localhost}") // 从配置读取,默认localhost
    private String host;

    @Value("${milvus.port:19530}") // Milvus默认端口
    private int port;

    /**
     * 创建Milvus客户端Bean
     * 连接协议说明:使用gRPC协议与Milvus服务通信
     */
    @Bean
    public MilvusClient milvusClient() {
        ConnectParam connectParam = ConnectParam.newBuilder()
            .withHost(host)
            .withPort(port)
            .build();
        return new MilvusClient(connectParam);
    }
}

/**
 * 模型服务配置类
 * 功能:配置与模型服务(如FastAPI服务)的HTTP连接
 */
@Configuration
public class ModelServerConfig {
    /**
     * 创建专用RestTemplate实例
     * 配置要点:
     * - 根地址指向模型服务器
     * - 设置合理的超时时间(根据模型推理时间调整)
     */
    @Bean
    public RestTemplate modelRestTemplate(
        @Value("${model.server.url}") String baseUrl) {
        return new RestTemplateBuilder()
            .rootUri(baseUrl)
            .setConnectTimeout(Duration.ofSeconds(30))  // 连接超时30秒
            .setReadTimeout(Duration.ofSeconds(120))    // 读取超时2分钟
            .build();
    }
}

2. 向量化服务(Service层)

java复制代码
/**
 * 文档向量化服务
 * 核心功能:调用远程模型服务将文本转换为向量
 */
@Service
public class VectorizationService {
    private final RestTemplate modelClient;

    public VectorizationService(RestTemplate modelClient) {
        this.modelClient = modelClient;
    }

    /**
     * 文本向量化方法
     * @param text 输入文本(建议长度不超过512 token)
     * @return 归一化后的768维浮点数组
     * 实现原理:
     * 1. 调用模型服务的/vectorize接口
     * 2. 接收返回的JSON格式向量
     * 3. 转换为Java原生数组格式
     */
    public float[] vectorize(String text) {
        // 构建符合模型服务要求的请求体
        Map<String, Object> request = Map.of(
            "text", text,
            "normalize", true  // 请求归一化向量
        );
        
        // 发送POST请求并解析响应
        Map<String, Object> response = modelClient.postForObject(
            "/vectorize", 
            request, 
            Map.class
        );
        
        // 类型转换处理(模型服务返回List<Double>)
        List<Double> embedding = (List<Double>) response.get("embedding");
        return convertToFloatArray(embedding);
    }

    /**
     * 类型转换辅助方法
     * 精度说明:Double转Float会损失精度,但对相似度计算影响可接受
     */
    private float[] convertToFloatArray(List<Double> list) {
        float[] array = new float[list.size()];
        for (int i = 0; i < list.size(); i++) {
            array[i] = list.get(i).floatValue();
        }
        return array;
    }
}

3. Milvus服务(Service层)

java复制代码
/**
 * Milvus向量数据库操作服务
 * 包含功能:集合管理、文档存储、相似性检索
 */
@Service
public class MilvusService {
    private static final String COLLECTION_NAME = "knowledge_base";
    private static final int VECTOR_DIM = 768; // 与模型输出维度一致
    
    private final MilvusClient milvusClient;

    public MilvusService(MilvusClient milvusClient) {
        this.milvusClient = milvusClient;
        initializeCollection();
    }

    /**
     * 初始化集合
     * 设计要点:
     * - 定义三个字段:主键ID、文本内容、向量字段
     * - 配置索引类型为IVF_FLAT(适合精确搜索)
     */
    private void initializeCollection() {
        if (!milvusClient.hasCollection(COLLECTION_NAME)) {
            // 构建字段模式
            FieldType idField = FieldType.newBuilder()
                .withName("id")
                .withDataType(DataType.Int64)
                .withPrimaryKey(true)
                .withAutoID(true)
                .build();

            FieldType contentField = FieldType.newBuilder()
                .withName("content")
                .withDataType(DataType.VarChar)
                .withMaxLength(2000) // 支持长文本存储
                .build();

            FieldType vectorField = FieldType.newBuilder()
                .withName("vector")
                .withDataType(DataType.FloatVector)
                .withDimension(VECTOR_DIM)
                .build();

            // 创建集合
            CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .addFieldType(idField)
                .addFieldType(contentField)
                .addFieldType(vectorField)
                .build();

            milvusClient.createCollection(createParam);

            // 创建索引(提升检索效率)
            CreateIndexParam indexParam = CreateIndexParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withFieldName("vector")
                .withIndexType(IndexType.IVF_FLAT)
                .withMetricType(MetricType.IP)  // 使用内积相似度
                .withExtraParam("{"nlist":1024}")
                .build();
            milvusClient.createIndex(indexParam);
        }
    }

    /**
     * 文档存储方法
     * @param content 文本内容
     * @param vector 对应向量
     * 注意:实际生产环境应添加批量插入接口
     */
    public void storeDocument(String content, float[] vector) {
        List<InsertParam.Field> fields = Arrays.asList(
            new InsertParam.Field("content", Collections.singletonList(content)),
            new InsertParam.Field("vector", Collections.singletonList(vector))
        );

        InsertParam insertParam = InsertParam.newBuilder()
            .withCollectionName(COLLECTION_NAME)
            .withFields(fields)
            .build();

        milvusClient.insert(insertParam);
        milvusClient.flush(COLLECTION_NAME); // 立即刷盘,生产环境可定期批量执行
    }

    /**
     * 相似文档检索
     * @param queryVector 查询向量
     * @param topK 返回结果数量
     * @return 相关文档内容列表
     * 参数说明:
     * - nprobe:搜索的聚类中心数量,平衡精度与速度
     */
    public List<String> retrieveSimilarDocuments(float[] queryVector, int topK) {
        SearchParam searchParam = SearchParam.newBuilder()
            .withCollectionName(COLLECTION_NAME)
            .withVectorFieldName("vector")
            .withVectors(Collections.singletonList(queryVector))
            .withTopK(topK)
            .withParams("{"nprobe": 32}") 
            .build();

        SearchResults results = milvusClient.search(searchParam);
        return parseResults(results);
    }

    /**
     * 解析Milvus返回结果
     * 返回字段说明:
     * - distance:相似度得分
     * - content:文档内容
     */
    private List<String> parseResults(SearchResults results) {
        return results.getResults().stream()
            .sorted((a,b) -> Float.compare(
                (Float)b.get("distance"), // 按相似度降序
                (Float)a.get("distance")))
            .map(entry -> (String) entry.get("content"))
            .collect(Collectors.toList());
    }
}

4. RAG增强服务

java复制代码
/**
 * RAG增强处理服务
 * 实现功能:查询扩展 + 结果重排序
 */
@Service
public class AugmentationService {
    private static final int INITIAL_RETRIEVE_NUM = 20; // 初步召回数量
    private static final int FINAL_TOP_K = 5;          // 最终返回数量

    private final VectorizationService vectorizationService;
    private final MilvusService milvusService;

    public AugmentationService(VectorizationService vectorizationService,
                              MilvusService milvusService) {
        this.vectorizationService = vectorizationService;
        this.milvusService = milvusService;
    }

    /**
     * 增强检索流程:
     * 1. 查询扩展:生成相关查询变体
     * 2. 多向量融合检索
     * 3. 结果重排序
     */
    public List<String> enhancedRetrieval(String originalQuery) {
        // 阶段一:查询扩展
        List<String> expandedQueries = queryExpansion(originalQuery);
        
        // 阶段二:多向量检索
        List<String> initialResults = new ArrayList<>();
        for (String query : expandedQueries) {
            float[] vector = vectorizationService.vectorize(query);
            initialResults.addAll(
                milvusService.retrieveSimilarDocuments(vector, INITIAL_RETRIEVE_NUM)
            );
        }
        
        // 阶段三:去重与重排序
        List<String> uniqueResults = new ArrayList<>(new LinkedHashSet<>(initialResults));
        return rerankResults(originalQuery, uniqueResults);
    }

    /**
     * 查询扩展方法
     * 技术原理:使用LLM生成相关查询变体,提升召回率
     */
    private List<String> queryExpansion(String originalQuery) {
        String prompt = String.format("""
            请生成3个与以下查询语义相似的变体,保持专业表述:
            原始查询:%s
            输出格式:JSON数组,字段为"queries"
            """, originalQuery);

        Map<String, Object> request = Map.of(
            "prompt", prompt,
            "max_tokens", 100
        );
        
        Map<String, Object> response = vectorizationService.getModelClient()
            .postForObject("/generate", request, Map.class);
        
        // 解析JSON响应,此处简化处理
        return Arrays.asList(originalQuery, "扩展查询1", "扩展查询2");
    }

    /**
     * 结果重排序方法
     * 技术原理:使用交叉编码器计算query-doc相关性
     */
    private List<String> rerankResults(String query, List<String> candidates) {
        Map<String, Object> rerankRequest = Map.of(
            "query", query,
            "documents", candidates
        );
        
        List<Double> scores = vectorizationService.getModelClient()
            .postForObject("/rerank", rerankRequest, List.class);
        
        // 创建带分数的结果列表
        List<Pair<String, Double>> scoredDocs = new ArrayList<>();
        for (int i = 0; i < candidates.size(); i++) {
            scoredDocs.add(new Pair<>(candidates.get(i), scores.get(i)));
        }
        
        // 按分数降序排序
        scoredDocs.sort((a,b) -> Double.compare(b.getValue(), a.getValue()));
        
        return scoredDocs.stream()
            .map(Pair::getKey)
            .limit(FINAL_TOP_K)
            .collect(Collectors.toList());
    }
}

5. 生成服务

java复制代码
/**
 * 答案生成服务
 * 功能:整合检索结果生成最终答案
 */
@Service
public class GenerationService {
    private final RestTemplate modelClient;

    public GenerationService(RestTemplate modelClient) {
        this.modelClient = modelClient;
    }

    /**
     * 生成最终答案
     * @param question 用户问题
     * @param contexts 相关上下文
     * @return 结构化答案
     * 提示工程技巧:
     * - 明确要求答案格式
     * - 限制幻觉生成
     */
    public String generateAnswer(String question, List<String> contexts) {
        String contextStr = contexts.stream()
            .limit(3) // 取最相关的前3个上下文
            .collect(Collectors.joining("\n\n"));
        
        String prompt = buildPrompt(question, contextStr);
        
        Map<String, Object> request = Map.of(
            "prompt", prompt,
            "temperature", 0.3,  // 控制生成随机性
            "max_tokens", 500
        );
        
        Map<String, Object> response = modelClient.postForObject(
            "/generate", 
            request, 
            Map.class
        );
        
        return postProcessAnswer((String) response.get("generated_text"));
    }

    /**
     * 构建提示模板
     * 设计要点:
     * - 明确角色设定
     * - 强调基于上下文
     * - 要求结构化输出
     */
    private String buildPrompt(String question, String context) {
        return String.format("""
            你是一个专业的知识问答助手,请严格根据提供的上下文信息回答问题。
            
            # 上下文:
            %s
            
            # 问题:
            %s
            
            # 要求:
            1. 如果上下文包含明确答案,用简洁的Markdown格式回答
            2. 如果信息不足,回答“根据现有资料无法确定”
            3. 禁止编造不存在的信息
            """, context, question);
    }

    /**
     * 后处理方法
     * 功能:过滤敏感词、移除重复内容等
     */
    private String postProcessAnswer(String rawAnswer) {
        // 示例:移除可能的重复段落
        return Arrays.stream(rawAnswer.split("\n"))
            .distinct()
            .collect(Collectors.joining("\n"));
    }
}

6. 控制器(Controller层)

java复制代码
/**
 * 文档管理接口
 * 安全建议:生产环境应添加权限校验
 */
@RestController
@RequestMapping("/api/documents")
public class DocumentController {
    private final VectorizationService vectorizationService;
    private final MilvusService milvusService;

    public DocumentController(VectorizationService vectorizationService,
                             MilvusService milvusService) {
        this.vectorizationService = vectorizationService;
        this.milvusService = milvusService;
    }

    /**
     * 文档上传接口
     * 优化建议:
     * - 添加文件大小限制
     * - 支持批量上传
     */
    @PostMapping
    public ResponseEntity<Map<String, Object>> uploadDocument(
        @RequestBody Map<String, String> request) {
        
        String content = request.get("content");
        if (content == null || content.isEmpty()) {
            throw new IllegalArgumentException("内容不能为空");
        }
        
        float[] vector = vectorizationService.vectorize(content);
        milvusService.storeDocument(content, vector);
        
        return ResponseEntity.ok(Map.of(
            "status", "success",
            "message", "文档存储成功",
            "vector_dim", vector.length
        ));
    }
}

/**
 * 问答服务接口
 */
@RestController
@RequestMapping("/api/qa")
public class QAController {
    private final AugmentationService augmentationService;
    private final GenerationService generationService;

    public QAController(AugmentationService augmentationService,
                       GenerationService generationService) {
        this.augmentationService = augmentationService;
        this.generationService = generationService;
    }

    /**
     * 问答接口
     * 性能优化建议:
     * - 添加异步处理
     * - 实现请求限流
     */
    @PostMapping
    public ResponseEntity<Map<String, Object>> answerQuestion(
        @RequestBody Map<String, String> request) {
        
        String question = request.get("question");
        if (question == null || question.isEmpty()) {
            throw new IllegalArgumentException("问题不能为空");
        }
        
        List<String> contexts = augmentationService.enhancedRetrieval(question);
        String answer = generationService.generateAnswer(question, contexts);
        
        return ResponseEntity.ok(Map.of(
            "question", question,
            "answer", answer,
            "references", contexts.subList(0, Math.min(3, contexts.size()))
        ));
    }
}

系统架构说明

  1. 数据流向

    复制代码
    用户请求
      → Spring Boot控制器
      → 向量化服务(调用模型API)
      → Milvus向量检索
      → 增强处理(重排序/扩展)
      → 生成服务(大模型API)
      → 返回结构化答案
    
  2. 关键技术点

    • 向量化:使用Sentence-BERT等模型生成文本嵌入
    • 向量检索:基于Milvus的近似最近邻搜索(ANN)
    • 增强处理:查询扩展、结果重排序等技术提升召回率
    • 生成控制:通过Prompt Engineering控制生成质量
  3. 性能优化方向

    • 实现批量处理接口
    • 添加缓存层(Redis)
    • 支持异步处理
    • 引入熔断机制(Resilience4j)

本实现方案可直接部署使用,需配合以下基础设施:

  • 模型推理服务(如FastAPI部署的NLP模型)
  • Milvus向量数据库(建议集群模式)
  • 大模型API服务(如GPT-3.5/4)