RAG 与向量数据库实战:文档问答系统搭建

71 阅读4分钟

1、什么是RAG?

  • RAG 是 Retrieval-Augmented Generation 的缩写,中文可以翻译为 “检索增强生成”
  • 简单来说,它是一种结合向量检索和大模型生成的问答/文本生成方法,用来解决大模型知识覆盖不全(例如不了解公司内部文档或最新资料)的问题。
  • 这是询问AI得到的答案。按照我的理解,就是先从(向量)数据库中找出与用户问题相似的内容,再把这些内容和用户的问题一并发给大模型,这样大模型在回答时就有了依据。

image.png

2、什么又是向量检索呢?

  • 说到这里,就会想到高中学过的向量。以下图为例,这是一组二维向量:从方向上看,A 和 B 比较相似;从大小(长度)上看,C 和 B 比较相似。向量检索常用的余弦相似度(后文创建 Qdrant 集合时使用的 Cosine 距离)比较的就是方向。

image.png

  • 然而在这里,嵌入(Embedding)模型生成的向量远不止 2 个维度。如下图的 Qdrant 截图所示,向量拥有 1024 个维度,不再只有 x、y 两个属性。打个比方:描述一个人可以用(姓名、性别、年龄、生日、国籍、身高、体重、血型、BMI、眼睛颜色、头发颜色...)等很多属性
  • 所以,通过在多个维度上进行比较,就能判断两段内容是否相似,从而实现“语义化搜索”

image.png

3、怎么实现

  • maven配置
<dependencies>
    <!-- Spring MVC: serves the /lc4j/rag01/* REST endpoints of the demo controller -->
    <!-- NOTE(review): no <version> elements appear here; versions are presumably managed by a
         parent POM and/or a langchain4j BOM that is not shown in this snippet. Confirm. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Basic API: OpenAI-compatible chat and embedding model clients
         (pointed at DashScope's compatible-mode endpoint in the configuration class) -->
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-open-ai</artifactId>
    </dependency>
    <!-- High-level AI Services API (AiServices, content retrievers, chat memory) -->
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j</artifactId>
    </dependency>
    <!-- Lombok: generates the @Slf4j loggers used by the config and controller classes -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
    </dependency>

    <!-- Reactor integration for streaming responses; not used by the code shown in this article -->
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-reactor</artifactId>
    </dependency>
    <!-- Qdrant integration: QdrantEmbeddingStore (pulls in the Qdrant Java client) -->
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-qdrant</artifactId>
    </dependency>
    <!-- Easy RAG: simplified RAG setup; presumably also what puts ApacheTikaDocumentParser
         (used to parse the PDF) on the classpath. Confirm. -->
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-easy-rag</artifactId>
    </dependency>
</dependencies>
  • 模型、qdrant配置
package paperfly.config;


import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import paperfly.service.ChatAssistant;

import java.util.List;

/**
 * Spring wiring for the RAG demo: the DashScope chat and embedding models (through the
 * OpenAI-compatible API), the Qdrant client and embedding store, and the RAG-enabled
 * {@link ChatAssistant}.
 */
@Configuration
@Slf4j
public class LLMConfig {
    /** DashScope's OpenAI-compatible endpoint, shared by the chat and embedding models. */
    private static final String DASHSCOPE_BASE_URL =
            "https://dashscope.aliyuncs.com/compatible-mode/v1";
    /** Name of the environment variable holding the DashScope API key. */
    private static final String API_KEY_ENV = "aliAi-key";
    private static final String QDRANT_HOST = "127.0.0.1";
    /** Port used by the Qdrant gRPC client. */
    private static final int QDRANT_GRPC_PORT = 6334;
    /** Collection the documents are ingested into and retrieved from. */
    private static final String COLLECTION_NAME = "doc-qdrant";

    /** Chat model (qwen-plus) used to generate the final answers. */
    @Bean
    public ChatModel chatModel() {
        return OpenAiChatModel.builder()
                .apiKey(apiKey())
                .modelName("qwen-plus")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /** Embedding model (text-embedding-v3) used both for ingestion and for query vectors. */
    @Bean
    public EmbeddingModel embeddingModel() {
        return OpenAiEmbeddingModel.builder()
                .apiKey(apiKey())
                .modelName("text-embedding-v3")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /** Plain-text (non-TLS) gRPC client for the local Qdrant instance. */
    @Bean
    public QdrantClient qdrantClient() {
        return new QdrantClient(
                QdrantGrpcClient.newBuilder(QDRANT_HOST, QDRANT_GRPC_PORT, false).build());
    }

    /**
     * Vector store backed by the {@value #COLLECTION_NAME} Qdrant collection.
     *
     * <p>Reuses the {@link #qdrantClient()} bean. Inside a {@code @Configuration} class this
     * call returns the container-managed singleton, so the store and the controller share one
     * gRPC channel. Previously the builder opened a second connection from host and port.
     */
    @Bean
    public EmbeddingStore<TextSegment> embeddingStore() {
        return QdrantEmbeddingStore.builder()
                .client(qdrantClient())
                .collectionName(COLLECTION_NAME)
                .build();
    }

    /**
     * Builds the RAG-enabled assistant. The content retriever embeds the user message, fetches
     * up to 5 similar segments from the store and adds them to the prompt sent to the chat model.
     */
    @Bean
    public ChatAssistant chatAssistant(ChatModel chatModel,
                                       EmbeddingStore<TextSegment> embeddingStore,
                                       EmbeddingModel embeddingModel) {
        EmbeddingStoreContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(5)
                .build();
        return AiServices.builder(ChatAssistant.class)
                .chatModel(chatModel)
                // ChatAssistant#chat has no @MemoryId parameter, so all callers presumably share
                // a single conversation memory (last 10 messages). Tagging each memory with its
                // id keeps this correct once a @MemoryId parameter is added.
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
                        .id(memoryId)
                        .maxMessages(10)
                        .build())
                .contentRetriever(contentRetriever)
                .build();
    }

    /**
     * Reads the DashScope API key from the environment.
     *
     * <p>NOTE: most POSIX shells cannot export a variable whose name contains '-'. The name is
     * kept for compatibility, but something like ALIAI_KEY would be more portable.
     *
     * @return the non-blank API key
     * @throws IllegalStateException if the variable is unset or blank, so startup fails with a
     *     clear message instead of failing later on the first API call
     */
    private static String apiKey() {
        String key = System.getenv(API_KEY_ENV);
        if (key == null || key.isBlank()) {
            throw new IllegalStateException(
                    "Environment variable '" + API_KEY_ENV + "' with the DashScope API key is not set");
        }
        return key;
    }
}
  • 创建qdrant集合
    • 注意 setSize(1024) 的取值必须与所用嵌入模型输出的向量维度一致,否则向集合写入向量时会因维度不匹配而失败
// Check the embedding model's actual vector dimension with the code below;
// the collection's setSize(...) must use this same value.
Embedding queryEmbedding = embeddedModel.embed("00000是什么?").content();
System.out.println("维度: " + queryEmbedding.vector().length);
/**
 * Creates the "doc-qdrant" collection (cosine distance, 1024-dim vectors) if it does not exist.
 *
 * @throws IOException if Qdrant reports an error or the call is interrupted
 */
@RequestMapping(value = "/rag01/createCollection")
public void createCollection() throws IOException {
    // The size must match the dimension of the vectors the embedding model produces.
    Collections.VectorParams vectorParams = Collections.VectorParams.newBuilder()
            .setDistance(Collections.Distance.Cosine)
            .setSize(1024)
            .build();
    try {
        // The *Async calls only start the gRPC request. Waiting on the futures makes failures
        // visible instead of silently dropping them, and ensures the collection exists
        // before this endpoint returns. The existence check makes repeated calls a no-op.
        if (!qdrantClient.collectionExistsAsync("doc-qdrant").get()) {
            qdrantClient.createCollectionAsync("doc-qdrant", vectorParams).get();
        }
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IOException("Interrupted while creating Qdrant collection doc-qdrant", e);
    } catch (java.util.concurrent.ExecutionException e) {
        throw new IOException("Failed to create Qdrant collection doc-qdrant", e.getCause());
    }
}
  • 将文件向量化
@RequestMapping(value = "/rag01/add")
public String add() throws IOException {
    Document document = FileSystemDocumentLoader.loadDocument("D:\01-doc\公司文件\2024_08_16_软件产品研发实施准则与规范v1.1\Java开发手册(黄山版).pdf", new ApacheTikaDocumentParser());
    document.metadata().put("author", "paperfly");
    // 2. 按段落切分
    DocumentByParagraphSplitter splitter = new DocumentByParagraphSplitter(500,50);
    List<TextSegment> segments = splitter.split(document);

    // 3. 分批调用 embedding(一次最多 10 条)
    int batchSize = 10;
    for (int i = 0; i < segments.size(); i += batchSize) {
        int end = Math.min(i + batchSize, segments.size());
        List<TextSegment> batch = segments.subList(i, end);

        // 调用 embedding API
        List<Embedding> embeddings = embeddedModel.embedAll(batch).content();

        // 存入向量数据库
        embeddingStore.addAll(embeddings, batch);
    }
    return "Inserted " + segments.size() + " chunks into Qdrant";
}
  • 最后完成查询
/**
 * Asks a fixed demo question through the RAG-enabled assistant. The assistant's content
 * retriever supplies matching document segments as context for the chat model.
 *
 * @return the model's answer
 */
@RequestMapping(value = "/rag01/ask")
public Object ask() throws IOException {
    final String question = "00000是什么?";
    final String answer = chatAssistant.chat(question);
    return answer;
}
  • 完整的代码如下
package paperfly.config;


import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import paperfly.service.ChatAssistant;

import java.util.List;

/**
 * Spring wiring for the RAG demo: the DashScope chat and embedding models (through the
 * OpenAI-compatible API), the Qdrant client and embedding store, and the RAG-enabled
 * {@link ChatAssistant}.
 */
@Configuration
@Slf4j
public class LLMConfig {
    /** DashScope's OpenAI-compatible endpoint, shared by the chat and embedding models. */
    private static final String DASHSCOPE_BASE_URL =
            "https://dashscope.aliyuncs.com/compatible-mode/v1";
    /** Name of the environment variable holding the DashScope API key. */
    private static final String API_KEY_ENV = "aliAi-key";
    private static final String QDRANT_HOST = "127.0.0.1";
    /** Port used by the Qdrant gRPC client. */
    private static final int QDRANT_GRPC_PORT = 6334;
    /** Collection the documents are ingested into and retrieved from. */
    private static final String COLLECTION_NAME = "doc-qdrant";

    /** Chat model (qwen-plus) used to generate the final answers. */
    @Bean
    public ChatModel chatModel() {
        return OpenAiChatModel.builder()
                .apiKey(apiKey())
                .modelName("qwen-plus")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /** Embedding model (text-embedding-v3) used both for ingestion and for query vectors. */
    @Bean
    public EmbeddingModel embeddingModel() {
        return OpenAiEmbeddingModel.builder()
                .apiKey(apiKey())
                .modelName("text-embedding-v3")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /** Plain-text (non-TLS) gRPC client for the local Qdrant instance. */
    @Bean
    public QdrantClient qdrantClient() {
        return new QdrantClient(
                QdrantGrpcClient.newBuilder(QDRANT_HOST, QDRANT_GRPC_PORT, false).build());
    }

    /**
     * Vector store backed by the {@value #COLLECTION_NAME} Qdrant collection.
     *
     * <p>Reuses the {@link #qdrantClient()} bean. Inside a {@code @Configuration} class this
     * call returns the container-managed singleton, so the store and the controller share one
     * gRPC channel. Previously the builder opened a second connection from host and port.
     */
    @Bean
    public EmbeddingStore<TextSegment> embeddingStore() {
        return QdrantEmbeddingStore.builder()
                .client(qdrantClient())
                .collectionName(COLLECTION_NAME)
                .build();
    }

    /**
     * Builds the RAG-enabled assistant. The content retriever embeds the user message, fetches
     * up to 5 similar segments from the store and adds them to the prompt sent to the chat model.
     */
    @Bean
    public ChatAssistant chatAssistant(ChatModel chatModel,
                                       EmbeddingStore<TextSegment> embeddingStore,
                                       EmbeddingModel embeddingModel) {
        EmbeddingStoreContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(5)
                .build();
        return AiServices.builder(ChatAssistant.class)
                .chatModel(chatModel)
                // ChatAssistant#chat has no @MemoryId parameter, so all callers presumably share
                // a single conversation memory (last 10 messages). Tagging each memory with its
                // id keeps this correct once a @MemoryId parameter is added.
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
                        .id(memoryId)
                        .maxMessages(10)
                        .build())
                .contentRetriever(contentRetriever)
                .build();
    }

    /**
     * Reads the DashScope API key from the environment.
     *
     * <p>NOTE: most POSIX shells cannot export a variable whose name contains '-'. The name is
     * kept for compatibility, but something like ALIAI_KEY would be more portable.
     *
     * @return the non-blank API key
     * @throws IllegalStateException if the variable is unset or blank, so startup fails with a
     *     clear message instead of failing later on the first API call
     */
    private static String apiKey() {
        String key = System.getenv(API_KEY_ENV);
        if (key == null || key.isBlank()) {
            throw new IllegalStateException(
                    "Environment variable '" + API_KEY_ENV + "' with the DashScope API key is not set");
        }
        return key;
    }
}
package paperfly.service;

/**
 * AI service interface. LangChain4j's AiServices implements it at runtime (see the
 * chatAssistant bean), adding retrieved document context and chat memory to each call.
 */
public interface ChatAssistant {
    /**
     * Sends a user message to the chat model and returns the model's reply.
     *
     * @param userMessage the user's question
     * @return the model's answer text
     */
    String chat(String userMessage);
}
package paperfly.controller;

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.*;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.grpc.Collections;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import paperfly.service.ChatAssistant;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.ExecutionException;

/**
 * Demo endpoints for the RAG walkthrough. They create the Qdrant collection, ingest a PDF,
 * run raw similarity searches, and ask questions through the RAG-enabled {@link ChatAssistant}.
 */
@RestController
@RequestMapping("/lc4j")
@Slf4j
public class LangChain4JChatRag01ChatController {
    /** Qdrant collection name; must match the collection configured for the EmbeddingStore bean. */
    private static final String COLLECTION_NAME = "doc-qdrant";
    /** Vector size of the collection; must equal the embedding model's output dimension. */
    private static final int VECTOR_SIZE = 1024;
    /** Maximum number of segments sent to the embedding API per call. */
    private static final int EMBED_BATCH_SIZE = 10;
    /**
     * Sample document to ingest. Backslashes must be doubled in a Java string literal: a single
     * "\0" or "\2" is an octal escape, and "\公" and "\J" are illegal escapes that break compilation.
     */
    private static final String DOC_PATH =
            "D:\\01-doc\\公司文件\\2024_08_16_软件产品研发实施准则与规范v1.1\\Java开发手册(黄山版).pdf";
    /** Returned by the query endpoints when the search finds nothing. */
    private static final String NO_MATCH = "No matching segment found";

    @Autowired
    private EmbeddingModel embeddedModel;
    @Autowired
    private ChatAssistant chatAssistant;
    @Autowired
    private QdrantClient qdrantClient;
    @Autowired
    private EmbeddingStore<TextSegment> embeddingStore;

    /**
     * Debug endpoint: embeds {@code prompt} and returns the embedding's string form.
     *
     * @param prompt text to embed (request parameter)
     */
    @RequestMapping(value = "/rag01/chat")
    public Object chat(String prompt) throws IOException {
        Response<Embedding> embed = embeddedModel.embed(prompt);
        String string = embed.content().toString();
        log.info("string: {}", string);
        return string;
    }

    /**
     * Creates the collection (cosine distance, {@value #VECTOR_SIZE}-dim vectors) if it is missing.
     *
     * @throws IOException if Qdrant reports an error or the call is interrupted
     */
    @RequestMapping(value = "/rag01/createCollection")
    public void createCollection() throws IOException {
        Collections.VectorParams vectorParams = Collections.VectorParams.newBuilder()
                .setDistance(Collections.Distance.Cosine)
                .setSize(VECTOR_SIZE)
                .build();
        try {
            // The *Async calls only start the gRPC request. Waiting on the futures makes
            // failures visible and ensures the collection exists before this endpoint returns.
            // The existence check makes repeated calls a no-op.
            if (!qdrantClient.collectionExistsAsync(COLLECTION_NAME).get()) {
                qdrantClient.createCollectionAsync(COLLECTION_NAME, vectorParams).get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted while creating Qdrant collection " + COLLECTION_NAME, e);
        } catch (ExecutionException e) {
            throw new IOException("Failed to create Qdrant collection " + COLLECTION_NAME, e.getCause());
        }
    }

    /**
     * Loads {@link #DOC_PATH}, splits it into paragraph-based segments, embeds them in batches
     * and stores vectors plus text in the vector store.
     *
     * <p>NOTE: every call inserts new points, so running it twice stores the document twice.
     *
     * @return a short summary with the number of stored segments
     */
    @RequestMapping(value = "/rag01/add")
    public String add() throws IOException {
        // 1. Load and parse the PDF with Apache Tika; tag it so searches can filter by author.
        Document document = FileSystemDocumentLoader.loadDocument(DOC_PATH, new ApacheTikaDocumentParser());
        document.metadata().put("author", "paperfly");

        // 2. Split by paragraph: segments of at most 500 chars with up to 50 chars of overlap.
        DocumentByParagraphSplitter splitter = new DocumentByParagraphSplitter(500, 50);
        List<TextSegment> segments = splitter.split(document);

        // 3. Embed in batches; the embedding API accepts at most 10 inputs per call.
        for (int i = 0; i < segments.size(); i += EMBED_BATCH_SIZE) {
            int end = Math.min(i + EMBED_BATCH_SIZE, segments.size());
            List<TextSegment> batch = segments.subList(i, end);

            List<Embedding> embeddings = embeddedModel.embedAll(batch).content();

            // 4. Store each vector together with its segment.
            embeddingStore.addAll(embeddings, batch);
        }
        return "Inserted " + segments.size() + " chunks into Qdrant";
    }

    /** Asks a fixed demo question through the RAG-enabled assistant. */
    @RequestMapping(value = "/rag01/ask")
    public Object ask() throws IOException {
        return chatAssistant.chat("00000是什么?");
    }

    /** Raw similarity search: returns the text of the single closest segment. */
    @RequestMapping(value = "/rag01/query")
    public Object query() throws IOException {
        Embedding queryEmbedding = embeddedModel.embed("00000是什么?").content();
        log.info("维度: {}", queryEmbedding.vector().length);
        EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
                .queryEmbedding(queryEmbedding)
                .maxResults(1)
                .build();
        return topMatchText(embeddingSearchRequest);
    }

    /** Similarity search restricted to segments whose "author" metadata is "paperfly2". */
    @RequestMapping(value = "/rag01/query2")
    public Object query2() throws IOException {
        Embedding queryEmbedding = embeddedModel.embed("咏鸡说的是什么?").content();
        EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
                .queryEmbedding(queryEmbedding)
                .filter(MetadataFilterBuilder.metadataKey("author").isEqualTo("paperfly2"))
                .maxResults(1)
                .build();
        return topMatchText(embeddingSearchRequest);
    }

    /**
     * Runs {@code request} and returns the text of the best match.
     *
     * <p>An empty collection, or a filter that matches nothing (add() only stores
     * author "paperfly"), yields no matches. Calling get(0) on that empty list would throw
     * IndexOutOfBoundsException, so {@link #NO_MATCH} is returned instead.
     */
    private String topMatchText(EmbeddingSearchRequest request) {
        List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();
        if (matches.isEmpty()) {
            log.info(NO_MATCH);
            return NO_MATCH;
        }
        EmbeddingMatch<TextSegment> best = matches.get(0);
        String text = best.embedded().text();
        log.info("score: {}", best.score());
        log.info("text: {}", text);
        return text;
    }
}