LangChain 深度剖析技术文档

75 阅读18分钟

LangChain 深度剖析技术文档

目录

  1. 概述
  2. 核心架构
  3. 核心组件深度解析
  4. 高级特性
  5. 实战应用
  6. 性能优化
  7. 最佳实践

概述

什么是 LangChain?

LangChain 是一个开源框架,专为开发基于大型语言模型(LLM)的应用程序而设计。它提供了一套完整的工具链,使开发者能够构建复杂的、上下文感知的智能应用。

核心价值

  • 模块化设计:提供可组合的组件,支持灵活的应用架构
  • 链式处理:支持复杂的多步骤推理和处理流程
  • 数据集成:无缝连接各种数据源和外部服务
  • 模型抽象:统一的接口支持多种 LLM 提供商

适用场景

  • 智能问答系统
  • 文档分析与总结
  • 代码生成与分析
  • 数据检索与处理
  • 多模态应用开发

核心架构

整体架构图

┌─────────────────────────────────────────────────────────────┐
│                    LangChain 架构                            │
├─────────────────────────────────────────────────────────────┤
│  Application Layer                                          │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐            │
│  │   Chains    │ │   Agents    │ │  Callbacks  │            │
│  └─────────────┘ └─────────────┘ └─────────────┘            │
├─────────────────────────────────────────────────────────────┤
│  Core Components Layer                                      │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐            │
│  │   Models    │ │   Prompts   │ │   Memory    │            │
│  └─────────────┘ └─────────────┘ └─────────────┘            │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐            │
│  │  Indexes    │ │    Tools    │ │   Output    │            │
│  │             │ │             │ │   Parsers   │            │
│  └─────────────┘ └─────────────┘ └─────────────┘            │
├─────────────────────────────────────────────────────────────┤
│  Data Layer                                                 │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐            │
│  │ Vector DBs  │ │ Document    │ │ External    │            │
│  │             │ │ Loaders     │ │ APIs        │            │
│  └─────────────┘ └─────────────┘ └─────────────┘            │
└─────────────────────────────────────────────────────────────┘

设计原则

  1. 组合性:所有组件都可以独立使用或组合使用
  2. 可扩展性:支持自定义组件和插件
  3. 标准化:统一的接口和数据格式
  4. 可观测性:完整的日志和监控支持

核心组件深度解析

1. Models(模型层)

Qwen 模型集成
from langchain.llms import Tongyi
from langchain.chat_models import ChatTongyi
import os

# 设置 DashScope API Key
os.environ["DASHSCOPE_API_KEY"] = "your-dashscope-api-key"

# Qwen 基础模型配置
qwen_llm = Tongyi(
    model_name="qwen-turbo",      # 或 "qwen-plus", "qwen-max"
    temperature=0.7,
    max_tokens=1024,
    top_p=0.8,
    streaming=False
)

# Qwen 聊天模型配置
qwen_chat = ChatTongyi(
    model_name="qwen-plus",
    temperature=0.7,
    max_tokens=1024,
    streaming=True  # 支持流式输出
)

# 本地 Qwen 模型(使用 HuggingFace)
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.llms import HuggingFacePipeline

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Chat")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Chat",
    device_map="auto",
    torch_dtype="auto"
)

local_qwen = HuggingFacePipeline(
    model=model,
    tokenizer=tokenizer,
    model_kwargs={
        "temperature": 0.7,
        "max_new_tokens": 1024,
        "do_sample": True
    }
)
Qwen 模型选择指南
  • qwen-turbo: 速度快,成本低,适合简单任务
  • qwen-plus: 平衡性能和成本,适合大多数应用
  • qwen-max: 最强性能,适合复杂推理任务
  • qwen-vl: 多模态模型,支持图像理解
LLM 抽象
from langchain.llms import Tongyi
from langchain.chat_models import ChatTongyi
from langchain.llms import HuggingFacePipeline

# Qwen 基础 LLM
llm = Tongyi(
    model_name="qwen-turbo",
    temperature=0.7,
    max_tokens=256,
    top_p=1
)

# Qwen 聊天模型
chat_model = ChatTongyi(
    model_name="qwen-plus",
    temperature=0.7,
    max_tokens=256
)

# 本地 Qwen 模型
local_llm = HuggingFacePipeline.from_model_id(
    model_id="Qwen/Qwen2-7B-Chat",
    task="text-generation",
    model_kwargs={"temperature": 0.7, "max_length": 256}
)
模型配置最佳实践
class ModelConfig:
    def __init__(self):
        self.temperature = 0.1  # 低温度保证一致性
        self.max_tokens = 1024
        self.retry_config = {
            "max_retries": 3,
            "backoff_factor": 2,
            "timeout": 30
        }
        
    def get_llm(self, model_type="qwen"):
        if model_type == "qwen":
            return Tongyi(
                model_name="qwen-turbo",
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                request_timeout=self.retry_config["timeout"]
            )

2. Prompts(提示模板)

基础提示模板
from langchain.prompts import PromptTemplate, ChatPromptTemplate

# 基础模板
basic_template = PromptTemplate(
    input_variables=["product", "audience"],
    template="""
    为 {product} 写一个针对 {audience} 的营销文案。
    
    要求:
    - 突出产品优势
    - 符合目标受众特点
    - 语言生动有趣
    
    营销文案:
    """
)

# 聊天模板
chat_template = ChatPromptTemplate.from_messages([
    ("system", "你是一个专业的技术文档撰写专家。"),
    ("human", "请为 {technology} 技术写一份详细的介绍文档。"),
    ("ai", "我会为您创建一份全面的技术文档,包括以下部分:"),
    ("human", "重点关注 {focus_area} 方面的内容。")
])
动态提示生成
class DynamicPromptGenerator:
    def __init__(self):
        self.templates = {
            "analysis": """
            请分析以下{data_type}数据:
            
            数据内容:{data}
            
            分析维度:
            {analysis_dimensions}
            
            请提供详细的分析报告。
            """,
            "summary": """
            请总结以下内容:
            
            {content}
            
            总结要求:
            - 长度不超过{max_length}字
            - 突出关键信息
            - 保持客观中性
            """
        }
    
    def generate_prompt(self, template_type, **kwargs):
        template = PromptTemplate(
            template=self.templates[template_type],
            input_variables=list(kwargs.keys())
        )
        return template.format(**kwargs)

3. Memory(记忆系统)

记忆类型详解
from langchain.memory import (
    ConversationBufferMemory,
    ConversationBufferWindowMemory,
    ConversationSummaryMemory,
    ConversationSummaryBufferMemory
)

# 缓冲记忆 - 保存完整对话历史
buffer_memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)

# 窗口记忆 - 只保留最近N轮对话
window_memory = ConversationBufferWindowMemory(
    k=5,  # 保留最近5轮对话
    memory_key="chat_history",
    return_messages=True
)

# 摘要记忆 - 对历史对话进行摘要
summary_memory = ConversationSummaryMemory(
    llm=llm,
    memory_key="chat_history",
    return_messages=True
)

# 摘要缓冲记忆 - 结合摘要和缓冲
summary_buffer_memory = ConversationSummaryBufferMemory(
    llm=llm,
    max_token_limit=1000,
    memory_key="chat_history",
    return_messages=True
)

4. Chains(链式处理)

基础链类型
from langchain.chains import (
    LLMChain,
    SimpleSequentialChain,
    SequentialChain,
    ConversationChain
)

# 基础 LLM 链
basic_chain = LLMChain(
    llm=llm,
    prompt=basic_template,
    verbose=True
)

# 简单顺序链
chain1 = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["topic"],
        template="为 {topic} 写一个大纲。"
    ),
    output_key="outline"
)

chain2 = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["outline"],
        template="根据以下大纲写一篇详细文章:\n{outline}"
    ),
    output_key="article"
)

simple_sequential_chain = SimpleSequentialChain(
    chains=[chain1, chain2],
    verbose=True
)

5. Tools(工具集成)

内置工具
from langchain.tools import (
    DuckDuckGoSearchRun,
    WikipediaQueryRun,
    PythonREPLTool,
    ShellTool
)

# 搜索工具
search_tool = DuckDuckGoSearchRun()

# 维基百科查询
wiki_tool = WikipediaQueryRun()

# Python 执行环境
python_tool = PythonREPLTool()

# Shell 命令执行
shell_tool = ShellTool()

tools = [search_tool, wiki_tool, python_tool]
自定义工具开发
from langchain.tools import BaseTool
from typing import Optional, Type
from pydantic import BaseModel, Field

class DatabaseQueryInput(BaseModel):
    query: str = Field(description="SQL查询语句")
    database: str = Field(description="数据库名称")

class DatabaseQueryTool(BaseTool):
    name = "database_query"
    description = "执行数据库查询的工具"
    args_schema: Type[BaseModel] = DatabaseQueryInput
    
    def _run(
        self, 
        query: str, 
        database: str, 
        run_manager: Optional[Any] = None
    ) -> str:
        """执行数据库查询"""
        try:
            connection = self._get_connection(database)
            result = connection.execute(query)
            return str(result.fetchall())
        except Exception as e:
            return f"查询错误: {str(e)}"

6. Agents(智能代理)

Agent 类型详解
from langchain.agents import (
    initialize_agent,
    AgentType
)

# ReAct Agent
react_agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.REACT_DOCSTORE,
    verbose=True,
    max_iterations=3,
    early_stopping_method="generate"
)

# OpenAI Functions Agent  
functions_agent = initialize_agent(
    tools=tools,
    llm=chat_model,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True
)

# Zero-shot React Agent
zero_shot_agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True
)

高级特性

1. 向量存储与检索

from langchain.vectorstores import Chroma, FAISS
from langchain.embeddings import DashScopeEmbeddings, HuggingFaceEmbeddings
from langchain.retrievers import VectorStoreRetriever

# Qwen 嵌入模型
qwen_embeddings = DashScopeEmbeddings(
    model="text-embedding-v1",  # 或 "text-embedding-v2"
    dashscope_api_key="your-dashscope-api-key"
)

# 本地嵌入模型(替代方案)
local_embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-zh-v1.5",  # 中文优化模型
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True}
)

# Chroma 向量数据库
chroma_db = Chroma(
    collection_name="qwen_documents",
    embedding_function=qwen_embeddings,
    persist_directory="./chroma_db"
)

# FAISS 向量数据库
faiss_db = FAISS.from_documents(
    documents=docs,
    embedding=qwen_embeddings
)

# 向量检索器
retriever = VectorStoreRetriever(
    vectorstore=chroma_db,
    search_type="similarity",
    search_kwargs={"k": 4}
)

2. 回调系统

from langchain.callbacks import StdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler

class CustomCallbackHandler(BaseCallbackHandler):
    """自定义回调处理器"""
    
    def __init__(self):
        self.logs = []
        
    def on_llm_start(self, serialized, prompts, **kwargs):
        """LLM开始执行时调用"""
        self.logs.append({
            "event": "llm_start",
            "prompts": prompts,
            "timestamp": time.time()
        })
        
    def on_llm_end(self, response, **kwargs):
        """LLM执行结束时调用"""
        self.logs.append({
            "event": "llm_end",
            "response": str(response),
            "timestamp": time.time()
        })

3. 输出解析器

from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field

class PersonInfo(BaseModel):
    name: str = Field(description="人物姓名")
    age: int = Field(description="年龄")
    occupation: str = Field(description="职业")
    skills: List[str] = Field(description="技能列表")

parser = PydanticOutputParser(pydantic_object=PersonInfo)

# 获取格式化指令
format_instructions = parser.get_format_instructions()

# 解析输出
parsed_result = parser.parse(llm_output)

实战应用

1. 基于 Qwen 的智能问答系统

from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.llms import Tongyi
from langchain.embeddings import DashScopeEmbeddings
from langchain.vectorstores import FAISS
import os

# 设置 API Key
os.environ["DASHSCOPE_API_KEY"] = "your-dashscope-api-key"

# 初始化 Qwen 模型
qwen_llm = Tongyi(
    model_name="qwen-plus",
    temperature=0.1,  # 低温度保证答案稳定性
    max_tokens=1024
)

# 初始化嵌入模型
qwen_embeddings = DashScopeEmbeddings(model="text-embedding-v2")

# 加载文档
loader = TextLoader("knowledge_base.txt", encoding="utf-8")
documents = loader.load()

# 文档分割
text_splitter = CharacterTextSplitter(
    chunk_size=500,    # 适合中文的块大小
    chunk_overlap=50,  # 重叠部分
    separator="\n\n"   # 按段落分割
)
docs = text_splitter.split_documents(documents)

# 创建向量数据库
vectorstore = FAISS.from_documents(docs, qwen_embeddings)

# 构建问答链
qa_chain = RetrievalQA.from_chain_type(
    llm=qwen_llm,
    chain_type="stuff",
    retriever=vectorstore.as_retriever(
        search_kwargs={"k": 3}  # 返回最相关的3个文档
    ),
    return_source_documents=True,  # 返回源文档
    verbose=True
)

# 使用问答系统
question = "什么是机器学习?"
result = qa_chain({"query": question})

print(f"问题:{question}")
print(f"答案:{result['result']}")
print(f"源文档数量:{len(result['source_documents'])}")

2. 基于 Qwen 的文档总结应用

from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.llms import Tongyi

# 初始化 Qwen 模型
qwen_llm = Tongyi(
    model_name="qwen-turbo",  # 总结任务使用更快的模型
    temperature=0.3,
    max_tokens=512
)

# 自定义总结提示模板
summarize_template = """
请对以下文档进行简洁准确的总结:

{text}

总结要求:
1. 提取关键信息和主要观点
2. 保持客观中立的语调
3. 不超过200字

总结:
"""

summarize_prompt = PromptTemplate(
    template=summarize_template,
    input_variables=["text"]
)

# 加载总结链(使用map_reduce方法处理长文档)
summarize_chain = load_summarize_chain(
    llm=qwen_llm,
    chain_type="map_reduce",
    map_prompt=summarize_prompt,
    combine_prompt=summarize_prompt,
    verbose=True
)

# 执行总结
summary = summarize_chain.run(docs)
print(f"文档总结:{summary}")

# 单文档总结示例
single_doc_chain = load_summarize_chain(
    llm=qwen_llm,
    chain_type="stuff",  # 对于较短的文档
    prompt=summarize_prompt
)

if len(docs) > 0 and len(docs[0].page_content) < 2000:
    single_summary = single_doc_chain.run([docs[0]])
    print(f"单文档总结:{single_summary}")

3. 基于 Qwen 的代码分析助手

from langchain.agents import create_python_agent
from langchain.agents.agent_toolkits import PythonREPLTool
from langchain.llms import Tongyi
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# 初始化 Qwen 模型(代码分析使用更强的模型)
qwen_code_llm = Tongyi(
    model_name="qwen-max",
    temperature=0.1,  # 低温度保证代码分析的准确性
    max_tokens=2048
)

# 创建Python代理
code_analysis_agent = create_python_agent(
    llm=qwen_code_llm,
    tool=PythonREPLTool(),
    verbose=True
)

# 代码分析任务
analysis_prompt = """
分析以下Python代码的时间复杂度和空间复杂度:

def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

请提供:
1. 时间复杂度分析
2. 空间复杂度分析  
3. 性能优化建议
4. 优化后的代码实现
"""

result = code_analysis_agent.run(analysis_prompt)
print(result)

# 自定义代码审查链
code_review_template = PromptTemplate(
    input_variables=["code"],
    template="""
    作为一个高级Python开发者,请对以下代码进行全面审查:
    
    {code}
    
    请从以下维度进行分析:
    1. **代码质量**: 可读性、命名规范、注释
    2. **性能分析**: 时间/空间复杂度、潜在瓶颈
    3. **安全性**: 输入验证、异常处理
    4. **最佳实践**: 是否遵循Python最佳实践
    5. **优化建议**: 具体的改进方案
    
    请提供详细的分析报告和优化后的代码。
    """
)

code_review_chain = LLMChain(
    llm=qwen_code_llm,
    prompt=code_review_template,
    verbose=True
)

# 执行代码审查
code_to_review = """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr
"""

review_result = code_review_chain.run(code=code_to_review)
print(f"代码审查结果:\n{review_result}")

性能优化

1. 缓存策略

from langchain.cache import InMemoryCache, SQLiteCache
import langchain

# 内存缓存
langchain.llm_cache = InMemoryCache()

# SQLite缓存
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")

2. 批处理优化

# 批量处理文档
batch_size = 10
results = []

for i in range(0, len(documents), batch_size):
    batch = documents[i:i+batch_size]
    batch_results = llm.generate([doc.page_content for doc in batch])
    results.extend(batch_results.generations)

3. 异步处理

import asyncio
from langchain.llms import Tongyi

async def process_documents_async(documents):
    llm = Tongyi(model_name="qwen-turbo")
    tasks = []
    
    for doc in documents:
        task = llm.agenerate([doc.page_content])
        tasks.append(task)
    
    results = await asyncio.gather(*tasks)
    return results

最佳实践

1. 错误处理

from langchain.schema import OutputParserException

def safe_chain_execution(chain, input_data):
    try:
        result = chain.run(input_data)
        return {"success": True, "result": result}
    except OutputParserException as e:
        return {"success": False, "error": f"解析错误: {e}"}
    except Exception as e:
        return {"success": False, "error": f"执行错误: {e}"}

2. 配置管理

from dataclasses import dataclass
from typing import Optional

@dataclass
class LangChainConfig:
    dashscope_api_key: str
    model_name: str = "qwen-turbo"
    temperature: float = 0.7
    max_tokens: int = 1000
    chunk_size: int = 1000
    chunk_overlap: int = 200
    vector_store_path: Optional[str] = None
    
    @classmethod
    def from_env(cls):
        return cls(
            dashscope_api_key=os.getenv("DASHSCOPE_API_KEY"),
            model_name=os.getenv("MODEL_NAME", "qwen-turbo"),
            temperature=float(os.getenv("TEMPERATURE", "0.7"))
        )

3. 监控和日志

import logging
from langchain.callbacks import get_openai_callback

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 成本监控 (Qwen模型)
try:
    result = chain.run(input_text)
    logger.info(f"Qwen模型调用成功")
except Exception as e:
    logger.error(f"Qwen模型调用失败: {e}")

4. 安全性考虑

import re

def sanitize_input(user_input: str) -> str:
    """清理用户输入,防止注入攻击"""
    # 移除潜在的恶意代码
    cleaned = re.sub(r'[<>\"\'&]', '', user_input)
    
    # 限制长度
    if len(cleaned) > 1000:
        cleaned = cleaned[:1000]
    
    return cleaned

def validate_output(output: str) -> bool:
    """验证输出内容"""
    # 检查敏感信息
    sensitive_patterns = [
        r'\b\d{4}-\d{4}-\d{4}-\d{4}\b',  # 信用卡号
        r'\b\d{3}-\d{2}-\d{4}\b',        # 社会安全号
    ]
    
    for pattern in sensitive_patterns:
        if re.search(pattern, output):
            return False
    
    return True

总结

LangChain 作为一个强大的 LLM 应用开发框架,提供了完整的工具链和抽象层,大大简化了智能应用的开发过程。通过其模块化的设计,开发者可以快速构建复杂的AI应用,同时保持代码的可维护性和可扩展性。

结合 Qwen 大模型的强大能力,LangChain 可以为中文应用场景提供更好的支持。Qwen 模型在中文理解、代码生成、逻辑推理等方面表现优异,特别适合构建面向中文用户的智能应用。

Qwen + LangChain 开发建议

  1. 模型选择策略

    • 简单任务使用 qwen-turbo(速度快,成本低)
    • 复杂推理使用 qwen-plus(平衡性能和成本)
    • 高要求任务使用 qwen-max(最强性能)
  2. 成本优化

    • 启用 LangChain 缓存机制
    • 合理设置 max_tokens 参数
    • 使用批处理优化多文档处理
  3. 中文优化

    • 针对中文优化文档分割策略
    • 使用中文优化的嵌入模型
    • 设计符合中文表达习惯的提示模板
  4. 安全性考虑

    • 对用户输入进行严格验证
    • 实施输出内容过滤
    • 保护 API 密钥安全

掌握 LangChain 的核心概念和 Qwen 模型的最佳实践,将帮助开发者更好地利用大语言模型的能力,构建真正有价值的中文智能应用。

推荐学习路径

  1. 基础入门:熟悉 LangChain 核心组件和 Qwen API
  2. 实践应用:构建简单的问答和总结应用
  3. 进阶开发:开发复杂的多步骤推理应用
  4. 生产部署:掌握性能优化和监控技巧
  5. 持续优化:根据用户反馈持续改进应用效果

附录:完整 RAG 系统核心代码实现

以下是一个基于 LangChain 和 Qwen 模型的完整 RAG(检索增强生成)系统实现,包含文档加载、向量化、检索和生成的完整流程。

RAG 系统架构

import os
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from pathlib import Path

# LangChain 核心组件
from langchain.llms import Tongyi
from langchain.chat_models import ChatTongyi
from langchain.embeddings import DashScopeEmbeddings
from langchain.vectorstores import Chroma, FAISS
from langchain.document_loaders import (
    TextLoader, PDFLoader, DirectoryLoader,
    UnstructuredMarkdownLoader, CSVLoader
)
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter
)
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import (
    VectorStoreRetriever,
    MultiQueryRetriever,
    ContextualCompressionRetriever
)
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.callbacks import StdOutCallbackHandler

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class RAGConfig:
    """RAG 系统配置类"""
    # API 配置
    dashscope_api_key: str
    
    # 模型配置
    llm_model_name: str = "qwen-plus"
    embedding_model_name: str = "text-embedding-v2"
    temperature: float = 0.1
    max_tokens: int = 2048
    
    # 文档处理配置
    chunk_size: int = 500
    chunk_overlap: int = 50
    max_docs_per_query: int = 4
    
    # 向量存储配置
    vector_store_type: str = "chroma"  # 或 "faiss"
    persist_directory: str = "./rag_vectorstore"
    collection_name: str = "rag_documents"
    
    # 检索配置
    search_type: str = "similarity"  # 或 "mmr"
    search_kwargs: Dict[str, Any] = None
    
    def __post_init__(self):
        if self.search_kwargs is None:
            self.search_kwargs = {"k": self.max_docs_per_query}
        
        # 设置环境变量
        os.environ["DASHSCOPE_API_KEY"] = self.dashscope_api_key

class DocumentProcessor:
    """文档处理器 - 负责加载和分割文档"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=["\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", " ", ""]
        )
    
    def load_documents(self, file_paths: List[str]) -> List[Any]:
        """加载多种格式的文档"""
        documents = []
        
        for file_path in file_paths:
            path = Path(file_path)
            
            try:
                if path.suffix.lower() == '.pdf':
                    loader = PDFLoader(str(path))
                elif path.suffix.lower() == '.txt':
                    loader = TextLoader(str(path), encoding='utf-8')
                elif path.suffix.lower() == '.md':
                    loader = UnstructuredMarkdownLoader(str(path))
                elif path.suffix.lower() == '.csv':
                    loader = CSVLoader(str(path))
                elif path.is_dir():
                    loader = DirectoryLoader(
                        str(path),
                        glob="**/*.{txt,md,pdf}",
                        show_progress=True
                    )
                else:
                    logger.warning(f"不支持的文件格式: {path.suffix}")
                    continue
                
                docs = loader.load()
                documents.extend(docs)
                logger.info(f"成功加载 {len(docs)} 个文档从 {file_path}")
                
            except Exception as e:
                logger.error(f"加载文档失败 {file_path}: {e}")
        
        return documents
    
    def split_documents(self, documents: List[Any]) -> List[Any]:
        """分割文档为小块"""
        try:
            chunks = self.text_splitter.split_documents(documents)
            logger.info(f"文档分割完成,共生成 {len(chunks)} 个文档块")
            return chunks
        except Exception as e:
            logger.error(f"文档分割失败: {e}")
            return []
    
    def process_documents(self, file_paths: List[str]) -> List[Any]:
        """完整的文档处理流程"""
        documents = self.load_documents(file_paths)
        if not documents:
            raise ValueError("没有成功加载任何文档")
        
        chunks = self.split_documents(documents)
        if not chunks:
            raise ValueError("文档分割失败")
        
        return chunks

class VectorStoreManager:
    """向量存储管理器"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.embeddings = DashScopeEmbeddings(
            model=config.embedding_model_name
        )
        self.vectorstore = None
    
    def create_vectorstore(self, documents: List[Any]) -> Any:
        """创建向量存储"""
        try:
            if self.config.vector_store_type.lower() == "chroma":
                self.vectorstore = Chroma.from_documents(
                    documents=documents,
                    embedding=self.embeddings,
                    collection_name=self.config.collection_name,
                    persist_directory=self.config.persist_directory
                )
                self.vectorstore.persist()
                
            elif self.config.vector_store_type.lower() == "faiss":
                self.vectorstore = FAISS.from_documents(
                    documents=documents,
                    embedding=self.embeddings
                )
                # 保存 FAISS 索引
                faiss_path = Path(self.config.persist_directory)
                faiss_path.mkdir(exist_ok=True)
                self.vectorstore.save_local(str(faiss_path))
            
            else:
                raise ValueError(f"不支持的向量存储类型: {self.config.vector_store_type}")
            
            logger.info(f"向量存储创建成功,共索引 {len(documents)} 个文档块")
            return self.vectorstore
            
        except Exception as e:
            logger.error(f"向量存储创建失败: {e}")
            raise
    
    def load_vectorstore(self) -> Any:
        """加载已存在的向量存储"""
        try:
            if self.config.vector_store_type.lower() == "chroma":
                self.vectorstore = Chroma(
                    collection_name=self.config.collection_name,
                    embedding_function=self.embeddings,
                    persist_directory=self.config.persist_directory
                )
                
            elif self.config.vector_store_type.lower() == "faiss":
                self.vectorstore = FAISS.load_local(
                    self.config.persist_directory,
                    self.embeddings
                )
            
            logger.info("向量存储加载成功")
            return self.vectorstore
            
        except Exception as e:
            logger.error(f"向量存储加载失败: {e}")
            raise
    
    def get_retriever(self, retriever_type: str = "basic") -> Any:
        """获取检索器"""
        if not self.vectorstore:
            raise ValueError("向量存储未初始化")
        
        base_retriever = VectorStoreRetriever(
            vectorstore=self.vectorstore,
            search_type=self.config.search_type,
            search_kwargs=self.config.search_kwargs
        )
        
        if retriever_type == "basic":
            return base_retriever
        
        elif retriever_type == "multi_query":
            # 需要 LLM 来生成多个查询
            llm = Tongyi(
                model_name=self.config.llm_model_name,
                temperature=0.1
            )
            return MultiQueryRetriever.from_llm(
                retriever=base_retriever,
                llm=llm
            )
        
        elif retriever_type == "compression":
            # 压缩检索器
            llm = Tongyi(
                model_name=self.config.llm_model_name,
                temperature=0.1
            )
            compressor = LLMChainExtractor.from_llm(llm)
            return ContextualCompressionRetriever(
                base_compressor=compressor,
                base_retriever=base_retriever
            )
        
        else:
            raise ValueError(f"不支持的检索器类型: {retriever_type}")

class RAGChatBot:
    """RAG 聊天机器人"""
    
    def __init__(self, config: RAGConfig, retriever: Any):
        self.config = config
        self.retriever = retriever
        
        # 初始化 LLM
        self.llm = ChatTongyi(
            model_name=config.llm_model_name,
            temperature=config.temperature,
            max_tokens=config.max_tokens
        )
        
        # 初始化记忆
        self.memory = ConversationBufferWindowMemory(
            k=5,  # 保留最近5轮对话
            memory_key="chat_history",
            output_key="answer",
            return_messages=True
        )
        
        # 设置自定义提示模板
        self.qa_prompt = self._create_qa_prompt()
        
        # 创建对话链
        self.conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            memory=self.memory,
            return_source_documents=True,
            verbose=True,
            combine_docs_chain_kwargs={"prompt": self.qa_prompt}
        )
    
    def _create_qa_prompt(self) -> PromptTemplate:
        """创建问答提示模板"""
        template = """
你是一个专业的AI助手,基于提供的上下文信息来回答用户问题。请遵循以下准则:

1. 仔细阅读上下文信息,基于这些信息来回答问题
2. 如果上下文信息不足以回答问题,请明确说明
3. 保持回答的准确性和客观性
4. 用简洁清晰的中文回答
5. 如果可能,提供具体的例子或详细解释

上下文信息:
{context}

历史对话:
{chat_history}

用户问题:{question}

请基于上下文信息回答用户问题:
        """.strip()
        
        return PromptTemplate(
            input_variables=["context", "chat_history", "question"],
            template=template
        )
    
    def ask(self, question: str) -> Dict[str, Any]:
        """询问问题"""
        try:
            logger.info(f"用户问题: {question}")
            
            response = self.conversation_chain({"question": question})
            
            result = {
                "answer": response["answer"],
                "source_documents": response.get("source_documents", []),
                "chat_history": response.get("chat_history", [])
            }
            
            logger.info(f"回答生成成功,使用了 {len(result['source_documents'])} 个参考文档")
            return result
            
        except Exception as e:
            logger.error(f"问答失败: {e}")
            return {
                "answer": f"抱歉,处理您的问题时出现错误: {str(e)}",
                "source_documents": [],
                "chat_history": []
            }
    
    def get_relevant_documents(self, query: str, k: int = None) -> List[Any]:
        """获取相关文档"""
        if k:
            # 临时修改检索数量
            original_k = self.retriever.search_kwargs.get("k", 4)
            self.retriever.search_kwargs["k"] = k
            
        docs = self.retriever.get_relevant_documents(query)
        
        if k:
            # 恢复原始设置
            self.retriever.search_kwargs["k"] = original_k
        
        return docs
    
    def clear_memory(self):
        """清除对话记忆"""
        self.memory.clear()
        logger.info("对话记忆已清除")

class RAGSystem:
    """完整的 RAG 系统"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.doc_processor = DocumentProcessor(config)
        self.vector_manager = VectorStoreManager(config)
        self.chatbot = None
    
    def initialize_from_documents(self, file_paths: List[str], retriever_type: str = "basic"):
        """从文档初始化 RAG 系统"""
        logger.info("开始初始化 RAG 系统...")
        
        # 1. 处理文档
        documents = self.doc_processor.process_documents(file_paths)
        
        # 2. 创建向量存储
        self.vector_manager.create_vectorstore(documents)
        
        # 3. 获取检索器
        retriever = self.vector_manager.get_retriever(retriever_type)
        
        # 4. 初始化聊天机器人
        self.chatbot = RAGChatBot(self.config, retriever)
        
        logger.info("RAG 系统初始化完成!")
    
    def load_existing_system(self, retriever_type: str = "basic"):
        """加载已存在的 RAG 系统"""
        logger.info("加载现有 RAG 系统...")
        
        # 1. 加载向量存储
        self.vector_manager.load_vectorstore()
        
        # 2. 获取检索器
        retriever = self.vector_manager.get_retriever(retriever_type)
        
        # 3. 初始化聊天机器人
        self.chatbot = RAGChatBot(self.config, retriever)
        
        logger.info("RAG 系统加载完成!")
    
    def chat(self, question: str) -> Dict[str, Any]:
        """与 RAG 系统对话"""
        if not self.chatbot:
            raise ValueError("RAG 系统尚未初始化")
        
        return self.chatbot.ask(question)
    
    def search_documents(self, query: str, k: int = 4) -> List[Any]:
        """搜索相关文档"""
        if not self.chatbot:
            raise ValueError("RAG 系统尚未初始化")
        
        return self.chatbot.get_relevant_documents(query, k)
    
    def clear_conversation(self):
        """清除对话历史"""
        if self.chatbot:
            self.chatbot.clear_memory()

# 使用示例
def main():
    """RAG 系统使用示例"""
    
    # 1. 配置 RAG 系统
    config = RAGConfig(
        dashscope_api_key="your-dashscope-api-key",
        llm_model_name="qwen-plus",
        embedding_model_name="text-embedding-v2",
        chunk_size=500,
        chunk_overlap=50,
        max_docs_per_query=4,
        vector_store_type="chroma",
        persist_directory="./rag_demo_db"
    )
    
    # 2. 初始化 RAG 系统
    rag_system = RAGSystem(config)
    
    # 3. 从文档初始化(首次运行)
    document_paths = [
        "./documents/",  # 文档目录
        "./knowledge_base.txt",  # 单个文件
        "./manual.pdf"  # PDF 文件
    ]
    
    try:
        # 如果是首次运行,从文档初始化
        rag_system.initialize_from_documents(
            document_paths, 
            retriever_type="compression"  # 使用压缩检索器
        )
        
        # 如果已有向量数据库,可以直接加载
        # rag_system.load_existing_system(retriever_type="multi_query")
        
    except Exception as e:
        logger.error(f"系统初始化失败: {e}")
        return
    
    # 4. 交互式对话
    print("\n=== RAG 智能问答系统 ===")
    print("输入 'quit' 退出,输入 'clear' 清除对话历史")
    
    while True:
        try:
            question = input("\n请输入您的问题: ").strip()
            
            if question.lower() == 'quit':
                break
            elif question.lower() == 'clear':
                rag_system.clear_conversation()
                print("对话历史已清除")
                continue
            elif not question:
                continue
            
            # 获取回答
            result = rag_system.chat(question)
            
            print(f"\n回答: {result['answer']}")
            
            # 显示参考文档信息
            if result['source_documents']:
                print(f"\n参考文档 ({len(result['source_documents'])} 个):")
                for i, doc in enumerate(result['source_documents'][:2]):
                    content = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
                    source = doc.metadata.get('source', '未知来源')
                    print(f"  [{i+1}] {source}: {content}")
            
        except KeyboardInterrupt:
            print("\n程序被用户中断")
            break
        except Exception as e:
            print(f"\n错误: {e}")
    
    print("\n谢谢使用!")

if __name__ == "__main__":
    main()

RAG 系统特性说明

1. 模块化设计
  • RAGConfig: 统一配置管理
  • DocumentProcessor: 文档加载和预处理
  • VectorStoreManager: 向量存储管理
  • RAGChatBot: 对话交互核心
  • RAGSystem: 系统整体封装
2. 多格式文档支持
  • 支持 PDF、TXT、Markdown、CSV 等格式
  • 智能文档分割,针对中文优化
  • 批量文档处理能力
3. 灵活的检索策略
  • 基础相似度检索
  • 多查询检索(生成多个相关查询)
  • 上下文压缩检索(提取关键信息)
4. 对话记忆管理
  • 窗口记忆机制
  • 对话历史跟踪
  • 上下文感知回答
5. 中文优化
  • 中文分词和文本分割
  • 中文提示模板
  • Qwen 模型深度集成
6. 生产级特性
  • 完整的错误处理
  • 日志记录和监控
  • 配置文件管理
  • 向量数据库持久化

快速开始

# 1. 安装依赖
# pip install langchain dashscope chromadb faiss-cpu

# 2. 配置 API Key
config = RAGConfig(
    dashscope_api_key="your-api-key-here"
)

# 3. 初始化系统
rag = RAGSystem(config)
rag.initialize_from_documents(["your_documents_path"])

# 4. 开始对话
result = rag.chat("这些文档的主要内容是什么?")
print(result['answer'])

这个完整的 RAG 实现提供了企业级的功能和性能,可以直接用于生产环境中的知识问答、文档分析等应用场景。