RAG的原理及应用

135 阅读9分钟

如何写好提示词这篇文章中,我们说明了该如何向大模型提问,但是有的时候我们会发现,即使提示词已经写的天花乱坠了,大模型还是在胡说八道,这是为什么?

原因呢,通常有以下几个:

没问清楚

就好比你是一名java程序员,有一天,产品经理问你,jvm该怎么优化?你脑子里瞬间反应出一大堆jvm知识,然后开始balabala,产品直接打断了你,说:你就告诉我该怎么优化就行,不用说那么多。

这个场景下问题出在哪呢?就是提的问题有问题,jvm优化,优化什么?gc?内存管理?还是啥?同理,如果提示词写的不好,大模型会找到一大堆看上去有关系但实际没什么关系的知识,一股脑的丢出来。

这个场景说的是提示词工程

缺乏背景知识

你现在还是一名java程序员,有一天,产品经理问你,go语言的垃圾回收的三色标记法是怎么回事?你表示我TM怎么知道,我又不写go。。。。

但如果这个时候你旁边的同事桌子上正好有一本《go语言从入门到放弃》,你是不是就可以翻翻书,然后告诉产品三色标记法是:如此如此,这般这般。

这个场景说的就是RAG(Retrieval AND Generation)

能力不足

你是仍然是一名java程序员,产品再次找到你,让你画一下当前需求的PRD,你表示我TM又不是产品,画可以但是我得去学一段时间,于是你废寝忘食学了三天三夜,画出了PRD图,但是你发现你忘记了psvm是什么。。。

这个场景说的就是微调

微调是可以补足模型缺失的能力,但也有可能把模型原有的能力调没

所以一般我们想从大模型那里得到想要的信息,会有三种方法,提示词工程,RAG,模型微调,

这三种方法的成本:提示词工程 < RAG < 模型微调,但是这三种方式的效果,要看具体场景,提示词工程之前的文章说过了,这篇文章,我们来看RAG

什么是RAG

RAG,中文名检索增强生成,是一种结合信息检索(Retrieval)和文本生成(Generation)的技术,RAG通过实时检索相关文档或信息,并将其作为上下文输入到生成模型中,从而提高生成结果的时效性和准确性。说人话就是,大模型的数据集里没有的东西,我们查好了发给它整理好再发给我们。

RAG的优势

知识时效性问题

大模型的训练数据通常是静态的,无法涵盖最新信息,而RAG可以检索外部知识库实时更新信息。

减少模型幻觉

通过引入外部知识,RAG能够减少模型生成 虚假或不准确内容的可能性。

提升专业领域回答质量

RAG能够结合垂直领域的专业知识 库,生成更具专业深度的回答

RAG的核心原理与流程

数据预处理

  1. 知识库构建:收集并整理文档,网页,数据库等多源数据,构建外部知识库

  2. 文档分块:将文档切分为适当大小的片段,方便后续检索,分块策略需要在语义完整与检索效率间权衡

  3. 向量化处理:使用嵌入模型将文本转换为向量,存储到向量数据库

检索阶段

  1. 查询处理:将用户输入的问题转化为向量,并在向量数据库中进行相似度检索,找到最相关的文本片段。
  2. 重排序:对检索结果进行相关性排序,选择最相关的片段作为生成阶段的输入

生成阶段

  1. 上下文组装:将检索到的文本片段与用户问题结合,形成增强的上下文输入
  2. 生成回答:大语言模型基于增强的上下文生成最终回答

图片.png

下面我们使用deepseek_v3+faiss来完成这一过程,

pdf文档来自百度文库

大模型服务使用阿里的灵积模型服务

首先,读取pdf文件并提取文件内容和页码信息

from PyPDF2 import PdfReader
from langchain.chains import RetrievalQA
from langchain_openai import OpenAI
from langchain_community.callbacks.manager import get_openai_callback
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.vectorstores import FAISS
from typing import List, Tuple
import os
import pickle

DASHSCOPE_API_KEY = ''

def extract_text_with_page_numbers(pdf) -> Tuple[str, List[int]]:
    """
    从PDF中提取文本并记录每行文本对应的页码

    参数:
        pdf: PDF文件对象

    返回:
        text: 提取的文本内容
        page_numbers: 每行文本对应的页码列表
    """
    text = ""
    page_numbers = []

    for page_number, page in enumerate(pdf.pages, start=1):
        extracted_text = page.extract_text()
        if extracted_text:
            text += extracted_text
            page_numbers.extend([page_number] * len(extracted_text.split("\n")))
        else:
            print(f"No text found on page {page_number}.")

    return text, page_numbers

pdf_reader = PdfReader('./浦发上海浦东发展银行西安分行个金客户经理考核办法.pdf')
# 提取文本和页码信息
text, page_numbers = extract_text_with_page_numbers(pdf_reader)

对文本进行处理,保存到向量数据库并持久化

def process_text_with_splitter(text: str, page_numbers: List[int], save_path: str = None) -> FAISS:
    """
    处理文本并创建向量存储

    参数:
        text: 提取的文本内容
        page_numbers: 每行文本对应的页码列表
        save_path: 可选,保存向量数据库的路径

    返回:
        knowledgeBase: 基于FAISS的向量存储对象
    """
    # 创建文本分割器,用于将长文本分割成小块
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", " ", ""],
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )

    # 分割文本
    chunks = text_splitter.split_text(text)
    # Logger.debug(f"Text split into {len(chunks)} chunks.")
    print(f"文本被分割成 {len(chunks)} 个块。")

    # 创建嵌入模型
    embeddings = DashScopeEmbeddings(
        model="text-embedding-v1",
        dashscope_api_key=DASHSCOPE_API_KEY,
    )

    # 从文本块创建知识库
    knowledgeBase = FAISS.from_texts(chunks, embeddings)
    print("已从文本块创建知识库。")

    # 存储每个文本块对应的页码信息
    page_numbers_unique = list(set(page_numbers))
    page_info = {chunk: page_numbers_unique[i] for i, chunk in enumerate(chunks)}

    # 如果提供了保存路径,则保存向量数据库和页码信息
    if save_path:
        # 确保目录存在
        os.makedirs(save_path, exist_ok=True)

        # 保存FAISS向量数据库
        knowledgeBase.save_local(save_path)
        print(f"向量数据库已保存到: {save_path}")

        # 保存页码信息到同一目录
        with open(os.path.join(save_path, "page_info.pkl"), "wb") as f:
            pickle.dump(page_info, f)
        print(f"页码信息已保存到: {os.path.join(save_path, 'page_info.pkl')}")

    return knowledgeBase

# 处理文本并创建知识库,同时保存到磁盘
save_dir = "./vector_db"
knowledgeBase = process_text_with_splitter(text, page_numbers, save_path=save_dir)

创建模型服务连接对象

from langchain_community.llms import Tongyi

llm = Tongyi(model_name="deepseek-v3", dashscope_api_key=DASHSCOPE_API_KEY)  # qwen-turbo

从本地加载向量数据库

def load_knowledge_base(load_path: str, embeddings=None) -> FAISS:
    """
    从磁盘加载向量数据库和页码信息

    参数:
        load_path: 向量数据库的保存路径
        embeddings: 可选,嵌入模型。如果为None,将创建一个新的DashScopeEmbeddings实例

    返回:
        knowledgeBase: 加载的FAISS向量数据库对象
    """
    # 如果没有提供嵌入模型,则创建一个新的
    if embeddings is None:
        embeddings = DashScopeEmbeddings(
            model="text-embedding-v1",
            dashscope_api_key=DASHSCOPE_API_KEY,
        )

    # 加载FAISS向量数据库,添加allow_dangerous_deserialization=True参数以允许反序列化
    knowledgeBase = FAISS.load_local(load_path, embeddings, allow_dangerous_deserialization=True)
    print(f"向量数据库已从 {load_path} 加载。")

    # 加载页码信息
    page_info_path = os.path.join(load_path, "page_info.pkl")
    if os.path.exists(page_info_path):
        with open(page_info_path, "rb") as f:
            page_info = pickle.load(f)
        knowledgeBase.page_info = page_info
        print("页码信息已加载。")
    else:
        print("警告: 未找到页码信息文件。")

    return knowledgeBase

# 创建嵌入模型
embeddings = DashScopeEmbeddings(
    model="text-embedding-v1",
    dashscope_api_key=DASHSCOPE_API_KEY,
)
# 从磁盘加载向量数据库
loaded_knowledgeBase = load_knowledge_base("./vector_db", embeddings)

到这里我们完成了外部数据存储到向量数据库的部分,下面我们看一下用户查询

# 设置查询问题
# 设置查询问题
query = "客户经理被投诉了,投诉一次扣多少分"
# query = "客户经理每年评聘申报时间是怎样的?"
if query:
    # 执行相似度搜索,找到与查询相关的文档
    docs = loaded_knowledgeBase.similarity_search(query)

    # 加载问答链
    # chain = load_qa_chain(llm, chain_type="stuff")
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever = loaded_knowledgeBase.as_retriever()
    )
    # 准备输入数据
    input_data = {"input_documents": docs, "query": query}

    # 使用回调函数跟踪API调用成本
    with get_openai_callback() as cost:
        # 执行问答链
        response = chain.invoke(input=input_data)
        print(f"查询已处理。成本: {cost}")
        print(f'问题:{response.get("query")}')
        print(f'答案:{response.get("result")}')
        print("来源:")

    # 记录唯一的页码
    unique_pages = set()

    # 显示每个文档块的来源页码
    for doc in docs:
        text_content = getattr(doc, "page_content", "")
        source_page = knowledgeBase.page_info.get(
            text_content.strip(), "未知"
        )

        if source_page not in unique_pages:
            unique_pages.add(source_page)
            print(f"文本块页码: {source_page}")

我们来看一下输出:

提取的文本长度: 3881 个字符。
文本被分割成 5 个块。
已从文本块创建知识库。
向量数据库已保存到: ./vector_db
页码信息已保存到: ./vector_db/page_info.pkl
向量数据库已从 ./vector_db 加载。
页码信息已加载。
查询已处理。成本: Tokens Used: 0
	Prompt Tokens: 0
		Prompt Tokens Cached: 0
	Completion Tokens: 0
		Reasoning Tokens: 0
Successful Requests: 1
Total Cost (USD): $0.0
问题:客户经理被投诉了,投诉一次扣多少分
答案:根据提供的信息,如果客户经理被投诉一次,每次投诉会扣2分。具体参考以下内容:

"客户服务效率低,态度生硬或不及时为客户提供维护服务,有客户投诉的,每投诉一次扣 2分"
来源:
文本块页码: 3
文本块页码: 5
文本块页码: 1
文本块页码: 4

完整代码:

from PyPDF2 import PdfReader
from langchain.chains.question_answering import load_qa_chain
from langchain_openai import OpenAI
from langchain_community.callbacks.manager import get_openai_callback
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.vectorstores import FAISS
from typing import List, Tuple
from langchain.chains import RetrievalQA
import os
import pickle
from langchain_community.llms import Tongyi



DASHSCOPE_API_KEY = '你的api key'


def extract_text_with_page_numbers(pdf) -> Tuple[str, List[int]]:
    """
    从PDF中提取文本并记录每行文本对应的页码

    参数:
        pdf: PDF文件对象

    返回:
        text: 提取的文本内容
        page_numbers: 每行文本对应的页码列表
    """
    text = ""
    page_numbers = []

    for page_number, page in enumerate(pdf.pages, start=1):
        extracted_text = page.extract_text()
        if extracted_text:
            text += extracted_text
            page_numbers.extend([page_number] * len(extracted_text.split("\n")))
        else:
            print(f"No text found on page {page_number}.")

    return text, page_numbers

def process_text_with_splitter(text: str, page_numbers: List[int], save_path: str = None) -> FAISS:
    """
    处理文本并创建向量存储

    参数:
        text: 提取的文本内容
        page_numbers: 每行文本对应的页码列表
        save_path: 可选,保存向量数据库的路径

    返回:
        knowledgeBase: 基于FAISS的向量存储对象
    """
    # 创建文本分割器,用于将长文本分割成小块
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", " ", ""],
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )

    # 分割文本
    chunks = text_splitter.split_text(text)
    # Logger.debug(f"Text split into {len(chunks)} chunks.")
    print(f"文本被分割成 {len(chunks)} 个块。")

    # 创建嵌入模型
    embeddings = DashScopeEmbeddings(
        model="text-embedding-v1",
        dashscope_api_key=DASHSCOPE_API_KEY,
    )

    # 从文本块创建知识库
    knowledgeBase = FAISS.from_texts(chunks, embeddings)
    print("已从文本块创建知识库。")

    # 存储每个文本块对应的页码信息
    page_numbers_unique = list(set(page_numbers))
    page_info = {chunk: page_numbers_unique[i] for i, chunk in enumerate(chunks)}
    knowledgeBase.page_info = page_info

    # 如果提供了保存路径,则保存向量数据库和页码信息
    if save_path:
        # 确保目录存在
        os.makedirs(save_path, exist_ok=True)

        # 保存FAISS向量数据库
        knowledgeBase.save_local(save_path)
        print(f"向量数据库已保存到: {save_path}")

        # 保存页码信息到同一目录
        with open(os.path.join(save_path, "page_info.pkl"), "wb") as f:
            pickle.dump(page_info, f)
        print(f"页码信息已保存到: {os.path.join(save_path, 'page_info.pkl')}")

def load_knowledge_base(load_path: str, embeddings=None) -> FAISS:
    """
    从磁盘加载向量数据库和页码信息

    参数:
        load_path: 向量数据库的保存路径
        embeddings: 可选,嵌入模型。如果为None,将创建一个新的DashScopeEmbeddings实例

    返回:
        knowledgeBase: 加载的FAISS向量数据库对象
    """
    # 如果没有提供嵌入模型,则创建一个新的
    if embeddings is None:
        embeddings = DashScopeEmbeddings(
            model="text-embedding-v1",
            dashscope_api_key=DASHSCOPE_API_KEY,
        )

    # 加载FAISS向量数据库,添加allow_dangerous_deserialization=True参数以允许反序列化
    knowledgeBase = FAISS.load_local(load_path, embeddings, allow_dangerous_deserialization=True)
    print(f"向量数据库已从 {load_path} 加载。")

    # 加载页码信息
    page_info_path = os.path.join(load_path, "page_info.pkl")
    if os.path.exists(page_info_path):
        with open(page_info_path, "rb") as f:
            page_info = pickle.load(f)
        knowledgeBase.page_info = page_info
        print("页码信息已加载。")
    else:
        print("警告: 未找到页码信息文件。")

    return knowledgeBase

# 读取PDF文件
pdf_reader = PdfReader('./浦发上海浦东发展银行西安分行个金客户经理考核办法.pdf')
# 提取文本和页码信息
text, page_numbers = extract_text_with_page_numbers(pdf_reader)

print(f"提取的文本长度: {len(text)} 个字符。")

# 处理文本并创建知识库,同时保存到磁盘
save_dir = "./vector_db"
knowledgeBase = process_text_with_splitter(text, page_numbers, save_path=save_dir)


# 创建嵌入模型
embeddings = DashScopeEmbeddings(
    model="text-embedding-v1",
    dashscope_api_key=DASHSCOPE_API_KEY,
)
# 从磁盘加载向量数据库
loaded_knowledgeBase = load_knowledge_base("./vector_db", embeddings)

llm = Tongyi(model_name="deepseek-v3", dashscope_api_key=DASHSCOPE_API_KEY)


while True:
    query = input("问题(q:退出):")
    if query == "Q".lower():
        break
    if query:
        # 执行相似度搜索,找到与查询相关的文档
        docs = loaded_knowledgeBase.similarity_search(query)

        # 加载问答链
        # chain = load_qa_chain(llm, chain_type="stuff")
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=loaded_knowledgeBase.as_retriever()
        )
        # 准备输入数据
        input_data = {"input_documents": docs, "query": query}

        # 使用回调函数跟踪API调用成本
        with get_openai_callback() as cost:
            # 执行问答链
            response = chain.invoke(input=input_data)
            print(f"查询已处理。成本: {cost}")
            print(f'问题:{response.get("query")}')
            print(f'答案:{response.get("result")}')
            print("来源:")

        # 记录唯一的页码
        unique_pages = set()

        # 显示每个文档块的来源页码
        for doc in docs:
            text_content = getattr(doc, "page_content", "")
            source_page = loaded_knowledgeBase.page_info.get(
                text_content.strip(), "未知"
            )

            if source_page not in unique_pages:
                unique_pages.add(source_page)
                print(f"文本块编号: {source_page}")

运行:

提取的文本长度: 3881 个字符。
文本被分割成 5 个块。
已从文本块创建知识库。
向量数据库已保存到: ./vector_db
页码信息已保存到: ./vector_db/page_info.pkl
向量数据库已从 ./vector_db 加载。
页码信息已加载。
问题(q:退出):个人资产质量发生跨月逾期,扣多少分
查询已处理。成本: Tokens Used: 0
	Prompt Tokens: 0
		Prompt Tokens Cached: 0
	Completion Tokens: 0
		Reasoning Tokens: 0
Successful Requests: 1
Total Cost (USD): $0.0
问题:个人资产质量发生跨月逾期,扣多少分
答案:根据提供的信息,个人资产质量发生跨月逾期的扣分标准如下:

- **单笔不超过10万元,当季收回者**,扣1分。
- **2笔以上累计金额不超过20万元,当季收回者**,扣2分。
- **累计超过20万元以上的**,扣4分。
- **逾期超过3个月,无论金额大小和笔数**,扣10分。

因此,具体的扣分取决于逾期的金额、笔数以及逾期时间。
来源:
文本块页码: 3
文本块页码: 5
文本块页码: 1
文本块页码: 2
问题(q:退出):q