3小时搞定RAG理论+实战篇

47 阅读6分钟

本文较长,建议点赞收藏。更多AI大模型应用开发学习视频及资料,在智泊AI

本节知识点

  • RAG是什么 ??
  • RAG优势 ??
  • RAG实战

RAG(Retrieval Augmented Generation)

检索增强生成 (RAG)  是一种 AI 框架,它结合了传统的信息检索系统(如搜索引擎或数据库)与大型语言模型(LLMs)的生成能力。

RAG 的核心思想是:

  1. 检索 (Retrieval) : 当用户提出问题时,首先从一个外部的、权威的知识库中检索出与问题最相关的几段信息(上下文)。
  2. 增强 (Augmented) : 将检索到的这些信息作为额外的上下文,与用户原始的问题一起,“增强” LLM 的输入。
  3. 生成 (Generation) : LLM 在这个增强的上下文中生成回答。

RAG优势

  • 减少幻觉 (Hallucinations): LLM 倾向于“编造”不存在的事实。RAG 通过提供真实、可靠的外部信息,大大降低了 LLM 产生不准确或虚假信息的可能性。
  • 知识时效性:LLM 的训练数据是静态的。RAG 允许你使用最新的数据(例如,你 MySQL 数据库中每天更新的课程信息),而无需重新训练或微调 LLM。
  • 特定领域知识:LLM 可能对你的公司内部数据、特定行业的术语或小众知识了解甚少。RAG 使 LLM 能够回答这些特定领域的问题。
  • 可追溯性/可解释性: 由于回答是基于检索到的文档生成的,你可以很容易地提供引用来源 (source_documents),让用户知道答案来自哪里,增加了透明度和信任度。
  • 成本效益:与耗时且昂贵的 LLM 微调相比,RAG 通常是更经济高效的解决方案。你只需更新向量数据库即可。

RAG实战

AI课程规划+课程ID

课程ID对应数据库中的, 这些数值AI是不能识别的, 需要增强之后输出。

环境准备

这里用到的包有很多, 主要有这些:

  • streamlit
  • langchain
  • langchain-core
  • pymysql
  • faiss-cpu

我目前用到的库, 可以选择性复制:

aiohappyeyeballs==2.6.1  
aiohttp==3.11.13  
aiosignal==1.3.2  
altair==5.5.0  
annotated-types==0.7.0  
anyio==4.8.0  
attrs==25.1.0  
blinker==1.9.0  
cachetools==5.5.2  
certifi==2025.1.31  
charset-normalizer==3.4.1  
click==8.1.8  
dashscope==1.22.2  
dataclasses-json==0.6.7  
distro==1.9.0  
frozenlist==1.5.0  
gitdb==4.0.12  
GitPython==3.1.44  
greenlet==3.1.1  
h11==0.14.0  
httpcore==1.0.7  
httpx==0.28.1  
httpx-sse==0.4.0  
idna==3.10  
Jinja2==3.1.6  
jiter==0.9.0  
jsonpatch==1.33  
jsonpointer==3.0.0  
jsonschema==4.23.0  
jsonschema-specifications==2024.10.1  
langchain==0.3.20  
langchain-community==0.3.19  
langchain-core==0.3.45  
langchain-deepseek==0.1.2  
langchain-openai==0.3.8  
langchain-text-splitters==0.3.6  
langsmith==0.3.15  
MarkupSafe==3.0.2  
marshmallow==3.26.1  
multidict==6.1.0  
mypy-extensions==1.0.0  
narwhals==1.30.0  
numpy==2.2.3  
openai==1.66.3  
orjson==3.10.15  
packaging==24.2  
pandas==2.2.3  
pillow==11.1.0  
propcache==0.3.0  
protobuf==5.29.3  
pyarrow==19.0.1  
pydantic==2.10.6  
pydantic-settings==2.8.1  
pydantic_core==2.27.2  
pydeck==0.9.1  
python-dateutil==2.9.0.post0  
python-dotenv==1.0.1  
pytz==2025.1  
PyYAML==6.0.2  
referencing==0.36.2  
regex==2024.11.6  
requests==2.32.3  
requests-toolbelt==1.0.0  
rpds-py==0.23.1  
six==1.17.0  
smmap==5.0.2  
sniffio==1.3.1  
SQLAlchemy==2.0.39  
streamlit==1.43.1  
tenacity==9.0.0  
tiktoken==0.9.0  
toml==0.10.2  
tornado==6.4.2  
tqdm==4.67.1  
typing-inspect==0.9.0  
typing_extensions==4.12.2  
tzdata==2025.1  
urllib3==2.3.0  
watchdog==6.0.0  
websocket-client==1.8.0  
yarl==1.18.3  
zstandard==0.23.0  
PyMySQL==1.1.1  
faiss-cpu==1.11.0

当前这里也需要有api-key, 这个langchain框架大部分都支持的。

分步处理

  • 外部知识库(mysql)
# src.api.useCourse.py  
import re  
import pymysql  
import pandas as pd  
from langchain_core.documents import Document  
  
def fetch_and_preprocess_data(table_name, course_name, course_id):  
    print("开始查询数据...")  
    # 建立连接  
    conn = pymysql.connect(  
        host='****',  
        user='****',  
        password='****',  
        database='****'  
    )  
      
    try:  
        with conn.cursor() as cursor:   
            # 查询数据  
            query = f"SELECT {course_name}{course_id} FROM {table_name} LIMIT 20"  
            cursor.execute(query)  
            data = cursor.fetchall()  
            print("查询结果:")  
            print(data)  
            df = pd.DataFrame(data, columns=['course_name''course_id'])  
  
        # 文本预处理函数  
        def preprocess_text(text):  
            ifnot isinstance(text, str):  
                return""  
            text = re.sub(r'<.*?>''', text)       # 去除HTML标签  
            text = re.sub(r'\s+'' ', text).strip() # 压缩空白符  
            return text  
  
        df['cleaned_text'] = df[course_name].apply(preprocess_text)  
        print("✅ 数据预处理完成。")  
        # df.to_csv('data.csv', index=False)  
        return df  
    except Exception as e:  
        print(f"❌ 查询或预处理MySQL数据失败: {e}")  
        return pd.DataFrame()  
    finally:  
        if conn:  
            conn.close()  
  
  
  
def getMysqlDocuments():  
    # 示例:从 'tb_course' 表中提取 'course_name' 字段  
    mysql_data_df = fetch_and_preprocess_data('tb_course''course_name''id')  
  
    ifnot mysql_data_df.empty:  
      documents = []  
      for index, row in mysql_data_df.iterrows():  
          # 确保文本和ID有效  
          if pd.notna(row['cleaned_text']) and row['cleaned_text'].strip() and pd.notna(row['course_id']):  
              doc = Document(  
                  page_content=row['cleaned_text'],  
                  metadata={  
                      "course_id": str(row['course_id']), # 建议将ID转换为字符串  
                      "original_content": row['course_name'# 存储原始内容,方便后续追溯  
                  }  
              )  
              documents.append(doc)  
  
      if documents:  
          print(f"✅ 成功构建 {len(documents)} 个 LangChain Document 对象。")  
          print("\n第一个 Document 示例:")  
          print(documents[0])  
      else:  
          print("⚠️ 没有有效的 Document 对象可供处理。")  
      return documents
  • 检索组件(embedding)
  1. 加载mysql或者CSV文件
# load-file.py  
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader  
from frontend.components.sidebar import check_csv_in_folder  
import streamlit as st  
  
from src.api.useCourse import getMysqlDocuments  
  
def load_file():  
  all_docs = []  
  load_files = check_csv_in_folder()  
  mysql_docs = getMysqlDocuments()  
  print(load_files)  
if len(load_files) == 0 :  
    st.warning("未检测到上传文件")  
else:  
    for file in load_files:  
      file_path = f"/data/raw/{file}"  
      # 加载 PDF 文件  
      pdf_loader = PyPDFLoader(file_path)  
      pdf_docs = pdf_loader.load()  
      all_docs = all_docs + pdf_docs  
  
# 合并所有文档  
return all_docs + mysql_docs

2. 数据清洗

# transform-file.py  
from src.model_manage.load_file import load_file  
from langchain.text_splitter import RecursiveCharacterTextSplitter  
from langchain_core.documents import Document  
  
# 过滤和清洗  
def clean_text(doc):  
    # 示例:移除多余的空格和换行符  
    cleaned_content = doc.page_content.replace("\n"" ").strip()  
    # 保留原始元数据(可选)  
    return Document(page_content=cleaned_content, metadata=doc.metadata)  
  
# 文本转换  
def transform_data():  
    # 加载文件  
    docs = load_file()  
    if docs isNone:  
        returnNone  
    # 分块处理  
    text_splitter = RecursiveCharacterTextSplitter(  
        chunk_size=500,   # 每块文本的最大长度  
        chunk_overlap=50# 块之间的重叠长度  
        length_function=len  
    )  
    split_docs = text_splitter.split_documents(docs)  
    # 清洗文本  
    cleaned_docs = [clean_text(doc) for doc in split_docs]  
    return cleaned_docs

0. 向量化处理

# embedding.py  
from langchain_community.embeddings import DashScopeEmbeddings  
from langchain_community.vectorstores import FAISS  
from src.model_manage.transform_file import transform_data  
  
"""  
  向量化文本数据  
  @author: petter  
"""  
def embedding_data():  
  cleaned_docs = transform_data()  
if cleaned_docs isNone:  
    returnNone  
# 向量化文本  
  embeddings = DashScopeEmbeddings(  
      model="text-embedding-v2",  
  )  
# 生成向量数据库  
  vector_db = FAISS.from_documents(cleaned_docs, embeddings)  
  
# 保存到本地(可选)  
  vector_db.save_local("faiss_index")  
return vector_db
  • 增强/上下文注入
import streamlit as st  
from langchain.chains import RetrievalQA  
from langchain_core.prompts import ChatPromptTemplate  
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  
from langchain_core.runnables import RunnableLambda  
from operator import itemgetter  
from src.model_manage.embedding import embedding_data  
  
vector_db = embedding_data()  
  
# 模型初始化  
# --------------------------  
@st.cache_resource  
def get_model():  
    """初始化并缓存DeepSeek模型"""  
    qa_chain = RetrievalQA.from_chain_type(  
      llm = 大模型名称,  
      chain_type="stuff"# 上下文检索  
      retriever=vector_db.as_retriever(  
            search_type="similarity",  
            search_kwargs={"k"3# 检索最相似的 3 个文档  
      ),  
      return_source_documents=True,  
    )  
    return qa_chain
  • 生成组件
def generate_stream_response(prompt):  
    """生成流式响应内容"""  
    model = get_model()  
    # 流式响应数据增强  
    final_chain = model | RunnableLambda(postprocess_ai_response)  
    chain = final_chain.invoke({  
        "query": prompt  
    })  
    try:  
        for chunk in chain:  
            yield chunk  
    except Exception as e:  
        yield f"⚠️ 请求失败:{str(e)}"
  • 响应内容增强
# 处理AI的回答, 添加自定义信息  
def postprocess_ai_response(input_dict: dict) -> str:  
    ai_response = input_dict["result"# AI的回答  
    source_documents = input_dict.get("source_documents", []) # 检索到的源文档  
  
    custom_info""  
    related_course_ids = set()  
  
    if source_documents:  
        for doc in source_documents:  
            course_id = doc.metadata.get("course_id")  
            if course_id:  
                related_course_ids.add(course_id)  
          
        if related_course_ids:  
            custom_info = f"\n\n--- 本次回答引用了以下课程ID:{', '.join(sorted(list(related_course_ids)))} ---"  
          
    final_output = f"{ai_response}{custom_info}"  
    return final_output

image.png

到这里, RAG核心功能代码已经全部完成了, 至于页面效果, 需要做些调整, ⛽️⛽️

其实这里增强的只有课程ID, 客户端便可以根据课程ID,查询出相对应的课程详情, 展示对应课程或者计划或者其他的方式呈现给用户, 这里相当API

学习资源推荐

如果你想更深入地学习大模型,以下是一些非常有价值的学习资源,这些资源将帮助你从不同角度学习大模型,提升你的实践能力。

本文较长,建议点赞收藏。更多AI大模型应用开发学习视频及资料,在智泊AI