使用Python构建RAG系统

0 阅读3分钟

内容整理自 www.bilibili.com/video/BV1wc…

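本机环境：macOS 26，Intel 芯片（平台标签 macosx_26_0_x86_64）。

PyTorch 从 2.3 起不再为 macOS x86_64 提供预编译包，所以 torch 只能使用 2.2.x（本文固定为 2.2.0）。其余依赖的版本要求如下：
onnxruntime需要使用1.15.0版本
python需要使用python3.11
numpy只能使用1.x版本（torch 2.2 基于 NumPy 1.x 编译，无法与 NumPy 2 一起使用）
sentence-transformers使用2.6.1版本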

执行步骤

# Activate the conda env that provides Python 3.11 (see the version constraints above).
. /opt/anaconda3/bin/activate && conda activate /opt/anaconda3/envs/rag
# Pin the interpreter explicitly so uv does not pick up a newer Python.
uv init . --python 3.11

# Add the version-constrained packages with their pins up front, instead of
# adding sentence-transformers unpinned and then removing / re-adding it.
# "numpy<2" must be quoted: an unquoted < is a shell redirection.
uv add "numpy<2" "torch==2.2.0" "onnxruntime==1.15.0" "sentence-transformers==2.6.1"
uv add chromadb google-genai python-dotenv

uv run --with jupyter jupyter lab

image.png

需要先在 Google AI Studio 申请 API key，并在项目目录下创建 .env 文件，内容为：

GEMINI_API_KEY=此处填写申请的key值

pyproject.toml

[project]
name = "rag"
version = "0.1.0"
description = "构建RAG系统"
readme = "README.md"
# The pinned stack below (torch 2.2.0, onnxruntime 1.15.0, numpy<2) was only
# set up and tested on Python 3.11, so cap the range rather than letting uv
# select a newer interpreter for which these pins may have no wheels.
requires-python = ">=3.11,<3.12"
dependencies = [
    "chromadb>=1.5.0",
    "google-genai>=1.64.0",
    "numpy<2",
    "onnxruntime==1.15.0",
    "python-dotenv>=1.2.1",
    "sentence-transformers==2.6.1",
    "torch==2.2.0",
]

common.py

import time
from typing import List
from sentence_transformers import SentenceTransformer

# Load the sentence-embedding model ("shibing624/text2vec-base-chinese") once,
# at import time, and print how long loading took. Every module that imports
# this one (directly or indirectly) pays this cost up front.
start = time.perf_counter()
embedding_model = SentenceTransformer("shibing624/text2vec-base-chinese")
print(f"加载embedding_model: {(time.perf_counter() - start):.4f} 秒")


def embed_chunk(chunk: str) -> List[float]:
    """
    Turn one piece of text into an embedding vector.

    The vector is L2-normalised (``normalize_embeddings=True``) and returned as
    a plain Python list so it can be handed straight to the vector store.

    :param chunk: the text to embed
    :return: the embedding as a list of floats
    """
    return embedding_model.encode(chunk, normalize_embeddings=True).tolist()

save_step.py

import time
from typing import List
from common import embed_chunk
import chromadb

# Open a persistent Chroma store at "./chroma.db" and get (or create) the
# collection named "default", timing the setup. This runs at import time;
# generate_step imports chromadb_collection from this module, so both scripts
# share the same store.
# NOTE(review): "./chroma.db" is resolved against the current working directory,
# not this file's directory. Running from elsewhere presumably opens a different
# (empty) store, so always run from the project directory.
start = time.perf_counter()
chromadb_client = chromadb.PersistentClient("./chroma.db")
chromadb_collection = chromadb_client.get_or_create_collection(name="default")
print(f"加载chromadb: {(time.perf_counter() - start):.4f} 秒")


def split_into_chunks(doc_file: str, separator: str = "\n\n") -> List[str]:
    """
    Split a text document into chunks for indexing.

    :param doc_file: path of the document; it is read as UTF-8
    :param separator: non-empty string that separates chunks; defaults to a
        blank line, i.e. one chunk per paragraph
    :return: the chunks in file order, stripped of surrounding whitespace,
        with empty / whitespace-only chunks dropped
    :raises ValueError: if ``separator`` is empty (raised by ``str.split``)
    """
    # Explicit encoding: the default is locale-dependent, and a non-UTF-8 locale
    # would fail to decode (or mis-decode) Chinese text.
    with open(doc_file, "r", encoding="utf-8") as file:
        content = file.read()

    # Three or more consecutive newlines, or a trailing blank line, would
    # otherwise produce empty chunks that get embedded and stored as useless
    # retrieval candidates.
    return [
        stripped
        for chunk in content.split(separator)
        if (stripped := chunk.strip())
    ]


def save_embeddings(chunks: List[str], embeddings: List[List[float]], batch_size: int = 500) -> None:
    """
    Store chunks and their embeddings in the Chroma collection.

    Chunk ``i`` is stored under id ``str(i)`` together with its text and vector.
    Records are written with ``upsert`` in batches. With ``upsert``, re-running
    the indexing after the document changes overwrites the entries for existing
    ids. ``add`` does not update existing ids.
    NOTE(review): ids left over from an earlier run with more chunks are not removed.

    :param chunks: chunk texts, in order
    :param embeddings: one embedding per chunk, in the same order
    :param batch_size: maximum number of records sent per ``upsert`` call
    :raises ValueError: if ``chunks`` and ``embeddings`` differ in length, or
        ``batch_size`` is less than 1
    """
    # zip() would silently drop the unmatched tail; a mismatch means a bug upstream.
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"got {len(chunks)} chunks but {len(embeddings)} embeddings"
        )
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    total = len(chunks)
    # Empty input: the loop body never runs, so nothing is sent to Chroma.
    for begin in range(0, total, batch_size):
        end = min(begin + batch_size, total)
        chromadb_collection.upsert(
            ids=[str(i) for i in range(begin, end)],
            documents=chunks[begin:end],
            embeddings=embeddings[begin:end],
        )


def save_step(doc_file: str = "doc.md"):
    """
    Build the vector index: split the document, embed every chunk, and store it.

    Only needs to run once when the system is first set up.

    :param doc_file: path of the source document; defaults to "doc.md",
        resolved against the current working directory
    :return: None
    """
    chunks = split_into_chunks(doc_file)
    embeddings = [embed_chunk(chunk) for chunk in chunks]
    save_embeddings(chunks, embeddings)


if __name__ == "__main__":
    # Index only when this file is run as a script. generate_step imports this
    # module (for chromadb_collection) and must not trigger indexing.
    save_step()

generate_step.py

import time
from typing import List
from common import embed_chunk
from save_step import chromadb_collection
from sentence_transformers import CrossEncoder
from dotenv import load_dotenv
from google import genai

# Load variables from a .env file (the setup notes put GEMINI_API_KEY there)
# into the process environment before the client is created.
load_dotenv()
# NOTE(review): genai.Client() gets no explicit api_key here, so the key is
# expected to come from the environment — confirm against the google-genai docs.
google_client = genai.Client()


def retrieve(query: str, top_k: int) -> List[str]:
    """
    Recall step: fetch up to ``top_k`` stored chunks closest to ``query``.

    :param query: the user's question
    :param top_k: how many candidates to ask the vector store for
    :return: the candidate chunk texts, in the order the store returns them
    """
    query_vectors = [embed_chunk(query)]
    hits = chromadb_collection.query(query_embeddings=query_vectors, n_results=top_k)
    # A single query vector was sent, so only the first per-query list is relevant.
    first_query_docs = hits['documents'][0]
    return first_query_docs


_CROSS_ENCODER_MODEL = 'cross-encoder/mmarco-mMiniLMv2-L12-H384-v1'
# Lazily-created singleton; None until the first rerank() that needs it.
_cross_encoder = None


def _get_cross_encoder() -> "CrossEncoder":
    """Load the cross-encoder on first use, then reuse it for every later call."""
    global _cross_encoder
    if _cross_encoder is None:
        start = time.perf_counter()
        _cross_encoder = CrossEncoder(_CROSS_ENCODER_MODEL)
        print(f"加载cross_encoder: {(time.perf_counter() - start):.4f} 秒")
    return _cross_encoder


def rerank(query: str, retrieved_chunks: List[str], top_k: int) -> List[str]:
    """
    Rerank step: score each (query, chunk) pair with a cross-encoder and keep the best.

    The model is loaded once and cached. Previously it was reloaded on every
    call, which took about 4 s per call in the run shown.

    :param query: the user's question
    :param retrieved_chunks: candidate chunks from the recall step
    :param top_k: how many of the highest-scoring chunks to return
    :return: up to ``top_k`` chunks, highest cross-encoder score first; ties
        keep their retrieval order
    """
    # Nothing to rank: skip loading and calling the model entirely.
    if not retrieved_chunks:
        return []

    cross_encoder = _get_cross_encoder()
    scores = cross_encoder.predict([(query, chunk) for chunk in retrieved_chunks])

    ranked = sorted(zip(retrieved_chunks, scores), key=lambda pair: pair[1], reverse=True)
    return [chunk for chunk, _ in ranked][:top_k]


def generate(query: str, chunks: List[str], model: str = "gemini-2.5-flash") -> str:
    """
    Generation step: ask Gemini to answer ``query`` using only the given chunks.

    :param query: the user's question
    :param chunks: context chunks (typically the reranked results), inserted
        into the prompt separated by blank lines
    :param model: Gemini model name passed to ``generate_content``
    :return: the model's text answer
    :raises RuntimeError: if the response carries no text
    """
    chunks_text = "\n\n".join(chunks)
    prompt = f"""你是一位知识助手,请根据用户的问题和下列片段生成准确的回答。

用户问题: {query}

相关片段:{chunks_text}

请基于上述内容作答,不要编造信息。"""

    start = time.perf_counter()
    response = google_client.models.generate_content(
        model=model,
        contents=prompt
    )
    print(f"gemini响应: {(time.perf_counter() - start):.4f} 秒")

    # The SDK's `text` property is Optional (None when the reply has no text
    # part); fail loudly instead of returning None from a `-> str` function.
    answer = response.text
    if answer is None:
        raise RuntimeError(f"model {model!r} returned no text for the query")
    return answer


if __name__ == "__main__":
    question = "哆啦A梦使用的3个秘密道具分别是什么?"
    # Recall 5 candidates by vector similarity, keep the best 3 after
    # cross-encoder reranking, then let Gemini answer from those 3.
    candidates = retrieve(question, 5)
    best = rerank(question, candidates, 3)
    print(generate(question, best))

注意：首次运行前需先调用一次 save_step()（项目目录下需准备 doc.md）建立向量索引，然后再运行 generate_step.py。测试通过，结果为：

加载embedding_model: 5.8553 秒
加载chromadb: 0.1427 秒
加载cross_encoder: 4.1846 秒
gemini响应: 11.8845 秒
根据片段,哆啦A梦使用的3个秘密道具分别是:

1.  **复制斗篷**:可以临时赋予超级战力。
2.  **时间停止手表**:能暂停时间五秒。
3.  **精神与时光屋便携版**:可在短时间内完成长时间的修行(例如,一分钟中完成一年修行)。