Langchain-chatchat系列四: Langchain源码debug

1,549 阅读5分钟

Langchain-chatchat系列四: Langchain源码debug

一、代码结构

  • configs/ 配置文件路径
  • server/ api服务、大模型服务等服务程序等核心代码
  • webui_pages/ webui服务
  • startup.py 启动脚本

img

zhuanlan.zhihu.com/p/655579960

│  .gitignore
│  CONTRIBUTING.md
│  init_database.py  用于初始化知识库
│  LICENSE
│  README.md
│  release.py
│  requirements.txt
│  requirements_api.txt
│  requirements_webui.txt
│  shutdown_all.sh 一键停止脚本,kill掉启动的服务
│  startup.py 一键启动
│  webui.py ui界面启动
├─chains
│      llmchain_with_history.py
│      
├─common
│      __init__.py
│      
├─configs
│      model_config.py.example 模型配置文件,配置使用的LLM和Emebdding模型;
│      server_config.py.example
│      __init__.py
├─embeddings
│      __init__.py
├─knowledge_base
│  └─samples  知识库
│      ├─content
│      │      test.txt 知识库上传的文档
│      └─vector_store 向量化后的知识
│              index.faiss 
│              index.pkl
│              
├─nltk_data Natural Language Toolkit (NLTK)是一个广泛使用的Python自然语言处理工具库           
├─server
│  │  api.py 用于启动API服务
│  │  api_allinone_stale.py
│  │  llm_api.py 用于启动LLM
│  │  llm_api_shutdown.py  
│  │  llm_api_stale.py
│  │  utils.py
│  │  webui_allinone_stale.py
│  │  
│  ├─chat 
│  │      chat.py 用于与LLM模型对话
│  │      knowledge_base_chat.py 用于与知识库对话
│  │      openai_chat.py
│  │      search_engine_chat.py 用于搜索引擎对话
│  │      utils.py
│  │      __init__.py
│  │      
│  ├─db 知识库的数据库
│  │  │  base.py
│  │  │  session.py
│  │  │  __init__.py
│  │  │  
│  │  ├─models
│  │  │      base.py 数据库表的基础属性
│  │  │      knowledge_base_model.py 知识库模型的表字段
│  │  │      knowledge_file_model.py 知识库文件的表字段
│  │  │      __init__.py
│  │  │      
│  │  └─repository
│  │          knowledge_base_repository.py
│  │          knowledge_file_repository.py
│  │          __init__.py
│  │          
│  ├─knowledge_base
│  │  │  kb_api.py 知识库API,创建、删除知识库;
│  │  │  kb_doc_api.py 知识库文件API,搜索、删除、更新、上传文档,重建向量库;
│  │  │  migrate.py 初始化 or 迁移重建知识库;
│  │  │  utils.py 提供了加载Embedding、获取文件加载器、文件转text的函数,可设置文本分割器;
│  │  │  __init__.py
│  │  │  
│  │  └─kb_service
│  │          base.py 向量库的抽象类
│  │          default_kb_service.py
│  │          faiss_kb_service.py faiss向量库子类
│  │          milvus_kb_service.py
│  │          pg_kb_service.py
│  │          __init__.py
│  │          
│  └─static
│          
├─tests
│  └─api
│          test_kb_api.py 测试知识库API
│          test_stream_chat_api.py 测试对话API
│          
├─text_splitter 各种文本分割器
│      ali_text_splitter.py 达摩院的文档分割
│      chinese_text_splitter.py 中文文本分割
│      zh_title_enhance.py 中文标题增强:判断是否是标题,然后在下一段文字的开头加入提示语与标题建立关联。
│      __init__.py
│      
└─webui_pages UI界面构建
    │  utils.py 简化api调用
    │  __init__.py
    │  
    ├─dialogue
    │      dialogue.py 问答功能,LLM问答和知识库问答,还有搜索引擎问答;
    │      __init__.py
    │      
    ├─knowledge_base 
    │      knowledge_base.py 知识库管理界面构建
    │      __init__.py
    │      
    └─model_config
            model_config.py 模型配置页面,TODO。应该是可以在界面上手动选择采用哪个LLM和Embedding模型。
            __init__.py
            

二、页面对话框调用链路

页面输入问题是,调用的接口为:

   app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat)

对应的chat 方法是:

​
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
               conversation_id: str = Body("", description="对话框ID"),
               history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
               history: Union[int, List[History]] = Body([],
                                                         description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                                         examples=[[
                                                             {"role": "user",
                                                              "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                             {"role": "assistant", "content": "虎头虎脑"}]]
                                                         ),
               stream: bool = Body(False, description="流式输出"),
               model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
               max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
               # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
               ):
    async def chat_iterator() -> AsyncIterable[str]:
        nonlocal history, max_tokens
        callback = AsyncIteratorCallbackHandler()
        callbacks = [callback]
        memory = None
​
        if conversation_id:
            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
            # 负责保存llm response到message db
            conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
                                                                chat_type="llm_chat",
                                                                query=query)
            callbacks.append(conversation_callback)
​
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None
​
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=callbacks,
        )
​
        if history: # 优先使用前端传入的历史消息
            history = [History.from_data(h) for h in history]
            prompt_template = get_prompt_template("llm_chat", prompt_name)
            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages(
                [i.to_msg_template() for i in history] + [input_msg])
        elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
            # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
            prompt = get_prompt_template("llm_chat", "with_history")
            chat_prompt = PromptTemplate.from_template(prompt)
            # 根据conversation_id 获取message 列表进而拼凑 memory
            memory = ConversationBufferDBMemory(conversation_id=conversation_id,
                                                llm=model,
                                                message_limit=history_len)
        else:
            prompt_template = get_prompt_template("llm_chat", prompt_name)
            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages([input_msg])
​
        chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
​
        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query}),
            callback.done),
        )
​
        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps(
                    {"text": token, "message_id": message_id},
                    ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield json.dumps(
                {"text": answer, "message_id": message_id},
                ensure_ascii=False)
​
        await task
​
    return StreamingResponse(chat_iterator(), media_type="text/event-stream")
​

调用链路是: 1.chat接口 2.chat_iterator() 3.StreamingResponse

chat.PNG

使用postman模拟用户页面问答的过程:

postman-chat.PNG

chat方法中调用了add_message_to_db方法:

@with_session
def add_message_to_db(session, conversation_id: str, chat_type, query, response="", message_id=None,
                      metadata: Dict = {}):
    """
    新增聊天记录
    """
    if not message_id:
        message_id = uuid.uuid4().hex
    m = MessageModel(id=message_id, chat_type=chat_type, query=query, response=response,
                     conversation_id=conversation_id,
                     meta_data=metadata)
    session.add(m)
    session.commit()
    return m.id

kb_config.py 数据库保存配置如下:

# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
    os.mkdir(KB_ROOT_PATH)
# 数据库默认存储路径。
# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"

查看sqlite的数据表和结构:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
​
# 创建数据库连接
engine = create_engine('sqlite:///info.db')
​
# 创建Session类,用于与数据库进行交互
Session = sessionmaker(bind=engine)
​
# 创建Session实例
session = Session()
​
try:
    # 开始事务
    with session.begin():
        # 执行SQL语句
        tables = session.execute('SELECT name FROM sqlite_master WHERE type="table"')
        # 打印所有表名
        for table in tables:
            print(table[0])
            
        messages = session.execute("SELECT * FROM message limit 2")
        # 处理查询结果
        for row in messages:
            print(row)
​
    # 提交事务
    session.commit()
except:
    # 回滚事务
    session.rollback()
    raise
finally:
    # 关闭连接
    session.close()
​
##原文链接:https://blog.csdn.net/weixin_62650212/article/details/130212100#四张表: conversation  message   knowledge_base   knowledge_file   file_doc  summary_chunk

SQLite对应的实体类定义:server.db.models 目录下:

class MessageModel(Base):
    """
    聊天记录模型
    """
    __tablename__ = 'message'
    id = Column(String(32), primary_key=True, comment='聊天记录ID')
    conversation_id = Column(String(32), default=None, index=True, comment='对话框ID')
    # chat/agent_chat等
    chat_type = Column(String(50), comment='聊天类型')
    query = Column(String(4096), comment='用户问题')
    response = Column(String(4096), comment='模型回答')
    # 记录知识库id等,以便后续扩展
    meta_data = Column(JSON, default={})
    # 满分100 越高表示评价越好
    feedback_score = Column(Integer, default=-1, comment='用户评分')
    feedback_reason = Column(String(255), default="", comment='用户评分理由')
    create_time = Column(DateTime, default=func.now(), comment='创建时间')
​
    def __repr__(self):
        return f"<message(id='{self.id}', conversation_id='{self.conversation_id}', chat_type='{self.chat_type}', query='{self.query}', response='{self.response}',meta_data='{self.meta_data}',feedback_score='{self.feedback_score}',feedback_reason='{self.feedback_reason}', create_time='{self.create_time}')>"
​

参考资料:

blog.csdn.net/ssw_1990/ar…

分词: zhuanlan.zhihu.com/p/638929185

blog.csdn.net/hy592070616…

SQLAlchemy

zhuanlan.zhihu.com/p/91169446

blog.csdn.net/weixin_6265…