
Calling the model locally
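The following code wraps a locally deployed chatglm3-6b checkpoint in a custom LangChain LLM subclass: load_model loads the tokenizer and weights with transformers, invoke / _call perform a single chat turn, and stream yields the reply incrementally through the model's stream_chat interface. Running it assumes torch, transformers, sentencepiece, and accelerate (needed for device_map="auto") are installed, and that the model files exist at the path passed to load_model.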
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel
from langchain_core.messages.ai import AIMessage
from typing import Any, List, Dict

class ChatGLM3(LLM):
    # Default generation parameters for the local ChatGLM3 model
    max_token: int = 8192
    do_sample: bool = True
    temperature: float = 0.3
    top_p: float = 0.8  # 0.0 effectively disables nucleus sampling; 0.8 matches the ChatGLM3 default
    tokenizer: Any = None
    model: Any = None
    history: List[List[str]] = []

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3"

    def load_model(self, modelPath: str = None):
        """
        Load the local ChatGLM3 model.
        :param modelPath: path to the local model, e.g. "I:/wd0717/wendaMain/wenda/model/chatglm3-6b"
        """
        if modelPath is None:
            raise ValueError("A local model path (modelPath) must be provided")
        tokenizer = AutoTokenizer.from_pretrained(
            modelPath,
            trust_remote_code=True,
            use_fast=True
        )
        model = AutoModel.from_pretrained(
            modelPath,
            trust_remote_code=True,
            device_map="auto"  # spreads weights over available GPUs/CPU; requires `accelerate`
        )
        model = model.eval()
        self.model = model
        self.tokenizer = tokenizer

    def _call(self, prompt: str, config: Dict = None, history: List[List[str]] = None) -> str:
        # Entry point used by LangChain's LLM base class; delegate to invoke() and return plain text
        if config is None:
            config = {}
        if history is None:
            history = []
        ai_message = self.invoke(prompt, config, history)
        return ai_message.content

    def invoke(self, prompt: str, config: Dict = None, history: List[List[str]] = None) -> AIMessage:
        if config is None:
            config = {}
        if history is None:
            history = []
        # Chains may pass a PromptValue instead of a plain string; convert before calling the model
        if not isinstance(prompt, str):
            prompt = prompt.to_string()
        response, history = self.model.chat(
            self.tokenizer,
            prompt,
            history=history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
            top_p=self.top_p
        )
        self.history = history
        return AIMessage(content=response)

    def stream(self, prompt: str, config: Dict = None, history: List[List[str]] = None):
        if config is None:
            config = {}
        if history is None:
            history = []
        if not isinstance(prompt, str):
            prompt = prompt.to_string()
        # stream_chat yields the full response generated so far; emit only the newly added suffix
        preResponse = ""
        for response, new_history in self.model.stream_chat(
            self.tokenizer,
            prompt,
            history=history,
            do_sample=self.do_sample,
            temperature=self.temperature,
            max_length=self.max_token
        ):
            self.history = new_history
            if preResponse == "":
                result = response
            else:
                result = response[len(preResponse):]
            preResponse = response
            yield result

if __name__ == "__main__":
    llm = ChatGLM3()
    llm.load_model(modelPath="I:/wd0717/wendaMain/wenda/model/chatglm3-6b")

    print("=== Standard call ===")
    result = llm.invoke("Hello, please introduce yourself")
    print(result.content)

    print("\n=== Streaming call ===")
    for chunk in llm.stream("Describe the features of ChatGLM3 in three sentences"):
        print(chunk, end="", flush=True)
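
Since the wrapper exposes invoke and stream directly, it can also be combined with other LangChain components. The sketch below is a minimal example, not part of the original code: it assumes the ChatGLM3 class above is available in the current module and reuses the same local model path, and the PromptTemplate text is purely illustrative.

from langchain_core.prompts import PromptTemplate

llm = ChatGLM3()
llm.load_model(modelPath="I:/wd0717/wendaMain/wenda/model/chatglm3-6b")

# Format the template to a plain string, then call the wrapper directly.
prompt = PromptTemplate.from_template("Answer briefly: {question}")
text = prompt.format(question="What is ChatGLM3?")
print(llm.invoke(text).content)

Formatting the template with format() first keeps the example independent of how the custom invoke signature interacts with the Runnable protocol.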