1、什么是RAG?
- RAG 是 Retrieval-Augmented Generation 的缩写,中文可以翻译为 “检索增强生成”
- 简单来说,它是一种结合向量检索和大模型生成的问答/文本生成方法,用来解决大模型知识覆盖不全(例如不了解私有文档或最新资料)的问题。
- 以上是询问 AI 得到的答案。按照我的理解,就是先从数据库中找出与用户问题相似的数据,再把这些数据和用户的问题一并发给大模型,这样大模型在回答时就有了依据。

2、什么又是向量检索呢?
- 说到这里,就会想到高中学过的向量。以下图为例,这是一组二维向量:从方向来看,A 和 B 比较相似;从大小(长度)来看,C 和 B 比较相似。

- 然而大模型使用的向量不止 2 个维度。如下图的 Qdrant 截图所示,这里的向量有 1024 个维度,远不止 x、y 两个属性。打个比方:描述一个人可以用姓名、性别、年龄、生日、国籍、身高、体重、血型、BMI、眼睛颜色、头发颜色……等许多属性,每个属性就相当于一个维度。
- 所以,通过在多个维度上进行比较,就能判断两段内容是否相似,从而实现“语义化搜索”。下文创建 Qdrant 集合时使用的余弦距离(Cosine)衡量的就是向量方向上的相似程度。

3、怎么实现
<!-- Dependencies for the Spring Boot + LangChain4j + Qdrant RAG demo below.
     No <version> elements are given: presumably versions come from a parent POM / BOM
     (e.g. spring-boot-starter-parent and langchain4j-bom) - TODO confirm. -->
<dependencies>
<!-- Spring MVC: needed for the @RestController endpoints -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- OpenAiChatModel / OpenAiEmbeddingModel, pointed at DashScope's OpenAI-compatible endpoint -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>
<!-- Core LangChain4j: AiServices, MessageWindowChatMemory, EmbeddingStoreContentRetriever -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
</dependency>
<!-- Lombok: provides @Slf4j -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<!-- Not used by the code shown here; presumably kept for streaming (Reactor) responses - TODO confirm -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-reactor</artifactId>
</dependency>
<!-- QdrantEmbeddingStore plus the io.qdrant client classes (QdrantClient, QdrantGrpcClient, Collections) -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-qdrant</artifactId>
</dependency>
<!-- Presumably the source of ApacheTikaDocumentParser used to parse the PDF (transitive) - verify -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-easy-rag</artifactId>
</dependency>
</dependencies>
package paperfly.config;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import paperfly.service.ChatAssistant;
import java.util.List;
/**
 * Wires the LangChain4j RAG pipeline.
 *
 * <p>It builds a Qwen chat model and a DashScope embedding model, both reached through DashScope's
 * OpenAI-compatible endpoint. It also builds a Qdrant client and a Qdrant-backed embedding store, plus a
 * {@link ChatAssistant} AI service that retrieves context from that store before answering.
 */
@Configuration
@Slf4j
public class LLMConfig {

    /**
     * Environment variable holding the DashScope API key.
     *
     * <p>A hyphenated name cannot be set with a plain {@code export} in POSIX shells. It is kept as-is
     * so existing setups keep working.
     */
    private static final String API_KEY_ENV = "aliAi-key";

    /** DashScope's OpenAI-compatible base URL, shared by the chat and embedding models. */
    private static final String DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";

    /** Qdrant gRPC host; the client bean and the embedding store must target the same server. */
    private static final String QDRANT_HOST = "127.0.0.1";

    /** Qdrant gRPC port; the client bean and the embedding store must target the same server. */
    private static final int QDRANT_GRPC_PORT = 6334;

    /** Collection that segments are stored in and retrieved from. */
    private static final String COLLECTION_NAME = "doc-qdrant";

    /** Number of most similar segments retrieved for each question. */
    private static final int MAX_RETRIEVED_SEGMENTS = 5;

    /** Number of messages retained per conversation memory. */
    private static final int MAX_MEMORY_MESSAGES = 10;

    /**
     * Reads the DashScope API key from the environment.
     *
     * @return the key, or {@code null} when the variable is unset (a warning is logged in that case)
     */
    private static String apiKey() {
        String key = System.getenv(API_KEY_ENV);
        if (key == null) {
            log.warn("Environment variable {} is not set; DashScope calls will fail", API_KEY_ENV);
        }
        return key;
    }

    /** Qwen chat model ("qwen-plus") with request/response logging enabled. */
    @Bean
    public ChatModel chatModel() {
        return OpenAiChatModel.builder()
                .apiKey(apiKey())
                .modelName("qwen-plus")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /**
     * Embedding model ("text-embedding-v3").
     *
     * <p>Its output dimension must match the vector size of the Qdrant collection.
     */
    @Bean
    public EmbeddingModel embeddingModel() {
        return OpenAiEmbeddingModel.builder()
                .apiKey(apiKey())
                .modelName("text-embedding-v3")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /** Low-level Qdrant client over plain-text gRPC (TLS disabled), used for collection management. */
    @Bean
    public QdrantClient qdrantClient() {
        return new QdrantClient(
                QdrantGrpcClient.newBuilder(QDRANT_HOST, QDRANT_GRPC_PORT, false).build());
    }

    /**
     * Qdrant-backed embedding store.
     *
     * <p>It is configured from host/port rather than from the {@link QdrantClient} bean, so the shared
     * constants keep both pointed at the same server.
     */
    @Bean
    public EmbeddingStore<TextSegment> embeddingStore() {
        return QdrantEmbeddingStore.builder()
                .host(QDRANT_HOST)
                .port(QDRANT_GRPC_PORT)
                .collectionName(COLLECTION_NAME)
                .build();
    }

    /**
     * Builds the RAG assistant.
     *
     * <p>Each question is embedded with {@code embeddingModel}. The closest segments from
     * {@code embeddingStore} are added as context before calling {@code chatModel}.
     *
     * @param chatModel      model that generates the answer
     * @param embeddingStore store searched for relevant segments
     * @param embeddingModel model used to embed the question; must be the one used at ingestion time
     * @return the generated {@link ChatAssistant} implementation
     */
    @Bean
    public ChatAssistant chatAssistant(ChatModel chatModel, EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel) {
        EmbeddingStoreContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(MAX_RETRIEVED_SEGMENTS)
                .build();
        return AiServices.builder(ChatAssistant.class)
                .chatModel(chatModel)
                // Propagate the requested memoryId so each conversation's memory carries its own id.
                // NOTE: ChatAssistant.chat declares no @MemoryId, so callers currently share one memory.
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
                        .id(memoryId)
                        .maxMessages(MAX_MEMORY_MESSAGES)
                        .build())
                .contentRetriever(contentRetriever)
                .build();
    }
}
- 创建 Qdrant 集合
- 注意:setSize(1024) 的配置要与所用嵌入模型输出的向量维度一致,否则写入向量时 Qdrant 会报维度不匹配的错误。可以用下面的代码打印实际维度进行确认:
// Embed a sample question and print the length of the resulting vector;
// this is the number that setSize(...) must be set to.
Embedding queryEmbedding = embeddedModel
        .embed("00000是什么?")
        .content();
int dimension = queryEmbedding.vector().length;
System.out.println("维度: " + dimension);
/**
 * Creates the "doc-qdrant" collection with 1024-dimensional vectors compared by cosine distance.
 *
 * <p>The size must equal the embedding model's output dimension.
 *
 * @throws IllegalStateException if Qdrant rejects the request (e.g. it is unreachable or the
 *                               collection already exists) or the wait is interrupted
 */
@RequestMapping(value = "/rag01/createCollection")
public void createCollection() throws IOException {
    Collections.VectorParams vectorParams = Collections.VectorParams.newBuilder()
            .setDistance(Collections.Distance.Cosine)
            .setSize(1024)
            .build();
    try {
        // Wait for the async call so failures reach the caller instead of vanishing with the future.
        qdrantClient.createCollectionAsync("doc-qdrant", vectorParams).get();
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException("Interrupted while creating Qdrant collection doc-qdrant", e);
    } catch (java.util.concurrent.ExecutionException e) {
        throw new IllegalStateException("Failed to create Qdrant collection doc-qdrant", e.getCause());
    }
}
@RequestMapping(value = "/rag01/add")
public String add() throws IOException {
Document document = FileSystemDocumentLoader.loadDocument("D:\01-doc\公司文件\2024_08_16_软件产品研发实施准则与规范v1.1\Java开发手册(黄山版).pdf", new ApacheTikaDocumentParser());
document.metadata().put("author", "paperfly");
DocumentByParagraphSplitter splitter = new DocumentByParagraphSplitter(500,50);
List<TextSegment> segments = splitter.split(document);
int batchSize = 10;
for (int i = 0; i < segments.size(); i += batchSize) {
int end = Math.min(i + batchSize, segments.size());
List<TextSegment> batch = segments.subList(i, end);
List<Embedding> embeddings = embeddedModel.embedAll(batch).content();
embeddingStore.addAll(embeddings, batch);
}
return "Inserted " + segments.size() + " chunks into Qdrant";
}
/** Asks the RAG assistant a fixed demo question and returns its answer. */
@RequestMapping(value = "/rag01/ask")
public Object ask() throws IOException {
    final String question = "00000是什么?";
    final String answer = chatAssistant.chat(question);
    return answer;
}
package paperfly.config;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import paperfly.service.ChatAssistant;
import java.util.List;
/**
 * Wires the LangChain4j RAG pipeline.
 *
 * <p>It builds a Qwen chat model and a DashScope embedding model, both reached through DashScope's
 * OpenAI-compatible endpoint. It also builds a Qdrant client and a Qdrant-backed embedding store, plus a
 * {@link ChatAssistant} AI service that retrieves context from that store before answering.
 */
@Configuration
@Slf4j
public class LLMConfig {

    /**
     * Environment variable holding the DashScope API key.
     *
     * <p>A hyphenated name cannot be set with a plain {@code export} in POSIX shells. It is kept as-is
     * so existing setups keep working.
     */
    private static final String API_KEY_ENV = "aliAi-key";

    /** DashScope's OpenAI-compatible base URL, shared by the chat and embedding models. */
    private static final String DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";

    /** Qdrant gRPC host; the client bean and the embedding store must target the same server. */
    private static final String QDRANT_HOST = "127.0.0.1";

    /** Qdrant gRPC port; the client bean and the embedding store must target the same server. */
    private static final int QDRANT_GRPC_PORT = 6334;

    /** Collection that segments are stored in and retrieved from. */
    private static final String COLLECTION_NAME = "doc-qdrant";

    /** Number of most similar segments retrieved for each question. */
    private static final int MAX_RETRIEVED_SEGMENTS = 5;

    /** Number of messages retained per conversation memory. */
    private static final int MAX_MEMORY_MESSAGES = 10;

    /**
     * Reads the DashScope API key from the environment.
     *
     * @return the key, or {@code null} when the variable is unset (a warning is logged in that case)
     */
    private static String apiKey() {
        String key = System.getenv(API_KEY_ENV);
        if (key == null) {
            log.warn("Environment variable {} is not set; DashScope calls will fail", API_KEY_ENV);
        }
        return key;
    }

    /** Qwen chat model ("qwen-plus") with request/response logging enabled. */
    @Bean
    public ChatModel chatModel() {
        return OpenAiChatModel.builder()
                .apiKey(apiKey())
                .modelName("qwen-plus")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /**
     * Embedding model ("text-embedding-v3").
     *
     * <p>Its output dimension must match the vector size of the Qdrant collection.
     */
    @Bean
    public EmbeddingModel embeddingModel() {
        return OpenAiEmbeddingModel.builder()
                .apiKey(apiKey())
                .modelName("text-embedding-v3")
                .logRequests(true)
                .logResponses(true)
                .baseUrl(DASHSCOPE_BASE_URL)
                .build();
    }

    /** Low-level Qdrant client over plain-text gRPC (TLS disabled), used for collection management. */
    @Bean
    public QdrantClient qdrantClient() {
        return new QdrantClient(
                QdrantGrpcClient.newBuilder(QDRANT_HOST, QDRANT_GRPC_PORT, false).build());
    }

    /**
     * Qdrant-backed embedding store.
     *
     * <p>It is configured from host/port rather than from the {@link QdrantClient} bean, so the shared
     * constants keep both pointed at the same server.
     */
    @Bean
    public EmbeddingStore<TextSegment> embeddingStore() {
        return QdrantEmbeddingStore.builder()
                .host(QDRANT_HOST)
                .port(QDRANT_GRPC_PORT)
                .collectionName(COLLECTION_NAME)
                .build();
    }

    /**
     * Builds the RAG assistant.
     *
     * <p>Each question is embedded with {@code embeddingModel}. The closest segments from
     * {@code embeddingStore} are added as context before calling {@code chatModel}.
     *
     * @param chatModel      model that generates the answer
     * @param embeddingStore store searched for relevant segments
     * @param embeddingModel model used to embed the question; must be the one used at ingestion time
     * @return the generated {@link ChatAssistant} implementation
     */
    @Bean
    public ChatAssistant chatAssistant(ChatModel chatModel, EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel) {
        EmbeddingStoreContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(MAX_RETRIEVED_SEGMENTS)
                .build();
        return AiServices.builder(ChatAssistant.class)
                .chatModel(chatModel)
                // Propagate the requested memoryId so each conversation's memory carries its own id.
                // NOTE: ChatAssistant.chat declares no @MemoryId, so callers currently share one memory.
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
                        .id(memoryId)
                        .maxMessages(MAX_MEMORY_MESSAGES)
                        .build())
                .contentRetriever(contentRetriever)
                .build();
    }
}
package paperfly.service;
/**
 * LangChain4j AI-service contract for the RAG chat.
 *
 * <p>The implementation is generated at runtime by {@code AiServices.builder(ChatAssistant.class)}
 * in LLMConfig.
 */
public interface ChatAssistant {
/**
 * Sends one user message and returns the model's reply.
 *
 * <p>Retrieved document segments are added as context before the model is called. No
 * {@code @MemoryId} parameter is declared, so presumably all callers share the default chat memory.
 *
 * @param userMessage the user's question
 * @return the model's answer text
 */
String chat(String userMessage);
}
package paperfly.controller;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.*;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.grpc.Collections;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import paperfly.service.ChatAssistant;
import java.io.IOException;
import java.util.List;
@RestController
@RequestMapping("/lc4j")
@Slf4j
public class LangChain4JChatRag01ChatController {
@Autowired
private EmbeddingModel embeddedModel;
@Autowired
private ChatAssistant chatAssistant;
@Autowired
private QdrantClient qdrantClient;
@Autowired
private EmbeddingStore<TextSegment> embeddingStore;
@RequestMapping(value = "/rag01/chat")
public Object chat(String prompt) throws IOException {
Response<Embedding> embed = embeddedModel.embed(prompt);
Embedding content = embed.content();
String string = content.toString();
log.info("string: {}", string);
return string;
}
@RequestMapping(value = "/rag01/createCollection")
public void createCollection() throws IOException {
Collections.VectorParams vectorParams = Collections.VectorParams.newBuilder()
.setDistance(Collections.Distance.Cosine)
.setSize(1024)
.build();
qdrantClient.createCollectionAsync("doc-qdrant", vectorParams);
}
@RequestMapping(value = "/rag01/add")
public String add() throws IOException {
Document document = FileSystemDocumentLoader.loadDocument("D:\01-doc\公司文件\2024_08_16_软件产品研发实施准则与规范v1.1\Java开发手册(黄山版).pdf", new ApacheTikaDocumentParser());
document.metadata().put("author", "paperfly");
DocumentByParagraphSplitter splitter = new DocumentByParagraphSplitter(500,50);
List<TextSegment> segments = splitter.split(document);
int batchSize = 10;
for (int i = 0; i < segments.size(); i += batchSize) {
int end = Math.min(i + batchSize, segments.size());
List<TextSegment> batch = segments.subList(i, end);
List<Embedding> embeddings = embeddedModel.embedAll(batch).content();
embeddingStore.addAll(embeddings, batch);
}
return "Inserted " + segments.size() + " chunks into Qdrant";
}
@RequestMapping(value = "/rag01/ask")
public Object ask() throws IOException {
return chatAssistant.chat("00000是什么?");
}
@RequestMapping(value = "/rag01/query")
public Object query() throws IOException {
Embedding queryEmbedding = embeddedModel.embed("00000是什么?").content();
System.out.println("维度: " + queryEmbedding.vector().length);
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(1)
.build();
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
EmbeddingMatch<TextSegment> embeddingMatch = matches.get(0);
System.out.println(embeddingMatch.score());
System.out.println(embeddingMatch.embedded().text());
return embeddingMatch.embedded().text();
}
@RequestMapping(value = "/rag01/query2")
public Object query2() throws IOException {
Embedding queryEmbedding = embeddedModel.embed("咏鸡说的是什么?").content();
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.filter(MetadataFilterBuilder.metadataKey("author").isEqualTo("paperfly2"))
.maxResults(1)
.build();
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
EmbeddingMatch<TextSegment> embeddingMatch = matches.get(0);
System.out.println(embeddingMatch.score());
System.out.println(embeddingMatch.embedded().text());
return embeddingMatch.embedded().text();
}
}