LangChain: Calling a Local Model


Calling a local model

The code below wraps a locally downloaded ChatGLM3-6B checkpoint in LangChain's LLM base class, so the model can be driven through both regular (invoke) and streaming (stream) calls.

from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel
from langchain_core.messages.ai import AIMessage
from typing import Any, List, Dict, Optional

class ChatGLM3(LLM):
    # Generation parameters and handles to the loaded model
    max_token: int = 8192
    do_sample: bool = True
    temperature: float = 0.3
    top_p: float = 0.0
    tokenizer: Any = None
    model: Any = None
    history: List[List[str]] = []  # conversation history as [query, response] pairs

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses to tag this LLM type
        return "ChatGLM3"

    # Load the tokenizer and weights from a local checkpoint directory
    def load_model(self, modelPath: Optional[str] = None):
        """
        Load a local ChatGLM3 model.
        :param modelPath: local model path, e.g. "I:/wd0717/wendaMain/wenda/model/chatglm3-6b"
        """
        if modelPath is None:
            raise ValueError("A local model path (modelPath) must be provided")

        # trust_remote_code is required because ChatGLM3 ships its own model code
        tokenizer = AutoTokenizer.from_pretrained(
            modelPath,
            trust_remote_code=True,
            use_fast=True
        )
        
        model = AutoModel.from_pretrained(
            modelPath,
            trust_remote_code=True,
            device_map="auto"  # let accelerate place the weights on available devices
        )
        model = model.eval()  # switch to evaluation mode (disables dropout)
        
        self.model = model
        self.tokenizer = tokenizer

    # LangChain calls _call for plain text completion; it must return a string
    def _call(self, prompt: str, config: Optional[Dict] = None, history: Optional[List[List[str]]] = None) -> str:
        if config is None:
            config = {}
        if history is None:
            history = []
        ai_message = self.invoke(prompt, config, history)
        return ai_message.content  # return the text, not the AIMessage object

    # Non-streaming call: returns the complete reply as an AIMessage
    def invoke(self, prompt: str, config: Optional[Dict] = None, history: Optional[List[List[str]]] = None) -> AIMessage:
        if config is None:
            config = {}
        if history is None:
            history = []
        
        # Accept PromptValue inputs by converting them to a plain string
        if not isinstance(prompt, str):
            prompt = prompt.to_string()
        
        # Delegate to ChatGLM3's chat() with the configured sampling parameters
        response, history = self.model.chat(
            self.tokenizer,
            prompt,
            history=history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
            top_p=self.top_p
        )
        self.history = history
        return AIMessage(content=response)

    # Streaming call: yields incremental text chunks as they are generated
    def stream(self, prompt: str, config: Optional[Dict] = None, history: Optional[List[List[str]]] = None):
        if config is None:
            config = {}
        if history is None:
            history = []
        
        # Accept PromptValue inputs by converting them to a plain string
        if not isinstance(prompt, str):
            prompt = prompt.to_string()
        
        preResponse = ""
        # 补充stream_chat的完整参数,修正缩进
        for response, new_history in self.model.stream_chat(
            self.tokenizer,
            prompt,
            history=history,
            do_sample=self.do_sample,
            temperature=self.temperature,
            max_length=self.max_token
        ):
            self.history = new_history  # 同步更新历史记录
            if preResponse == "":
                result = response
            else:
                result = response[len(preResponse):]
            preResponse = response
            yield result
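
Because invoke writes the updated history back onto the instance, multi-turn conversations can feed it back in. A minimal sketch, assuming a model has been loaded with load_model as in the test code below; the questions are illustrative:

chat_llm = ChatGLM3()
chat_llm.load_model(modelPath="I:/wd0717/wendaMain/wenda/model/chatglm3-6b")
first = chat_llm.invoke("What is LangChain?")
# chat_llm.history now holds the first exchange, so the follow-up sees that context
followup = chat_llm.invoke("Summarize your previous answer in one sentence", history=chat_llm.history)
print(followup.content)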

# Test code
if __name__ == "__main__":
    # Instantiate the wrapper
    llm = ChatGLM3()
    # Load the local model (replace with your actual model path)
    llm.load_model(modelPath="I:/wd0717/wendaMain/wenda/model/chatglm3-6b")

    # Regular (non-streaming) call
    print("=== Regular call ===")
    result = llm.invoke("Hello, please introduce yourself")
    print(result.content)

    # Streaming call
    print("\n=== Streaming call ===")
    for chunk in llm.stream("Describe the key features of ChatGLM3 in three sentences"):
        print(chunk, end="", flush=True)
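
Because ChatGLM3 subclasses LangChain's LLM, it also composes with prompt templates. A minimal sketch, assuming langchain_core is installed and llm was loaded as in the test code above; the template text and variable name are illustrative:

from langchain_core.prompts import PromptTemplate

# Build a PromptValue; invoke() converts non-string prompts via to_string()
template = PromptTemplate.from_template("Answer briefly: {question}")
prompt_value = template.invoke({"question": "What is ChatGLM3?"})
print(llm.invoke(prompt_value).content)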