大模型的Web界面的实现
- 本项目源码来自tigerbot的web_demo.py文件,基于源文件进行修改并添加注释,让基础薄弱的同学能够更加快速的部署自己的大模型,并且实现流式输出。
- 本代码部分注释使用openai进行分析
- 在使用过程中,可以将第二部分代码的第6行换成自己模型的路径
#导入所需要的库
import torch
import os
import sys
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM , TextIteratorStreamer ,GenerationConfig
import mdtex2html
from threading import Thread
import gc
#修改python的“sys.path”,以便python能够找到和导入某些模块或包
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
在第6行中,torch.cuda:这是PyTorch中用于处理与CUDA相关的操作模块。CUDA是NVIDIA开发的并行计算平台和应用编程接口,它允许使用GPU进行通用目的的计算
#关闭并行处理功能
os.environ["TOKENIZERS_PARALLELISM"] = "false"
#最大生成文字的长度
max_generate_length :int = 1024
#导入模型,如果有经过微调的权重需要加载,可以预先合并权重
model_path= "你的模型路径(本地或者线上的)"
#如果模型加载正确则输出
print(f"loading model:{model_path}... ")
#用于获取当前gpu设备的索引,如果有多个GPUpytorch为每个GPU分配一个索引,从0开始。
#将返回的设备索引存储到变量device中,可以使用这个索引进行后续的GPU的相关操作
device = torch.cuda.current_device()
#从预训练模型的加载路径加载模型及其生成配置,以便后续进行文本生成任务
generation_config=GenerationConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path,torch_dtype=torch.float16,device_map='auto')
加载一个与预训练模型相对应的标记器(tokenizer)
#加载一个给定的预训练模型相对应的标记器,并对其进行配置,使其可以对文本进行适当的填充和截断。
tokenizer = AutoTokenizer.from_pretrained(
model_path,
cache_dir=None,
#标记器的最大长度。
model_max_length=max_generate_length,
#当文本小于“model_max_length”时,会在文本的左侧添加填充标记
padding_size="left",
#当文本大于“model_max_length”时,会在文本的左侧进行截断
truncation_size='left',
padding=True,
truncation=True
)
#确保标记器有一个合适的最大长度
if tokenizer.model_max_length is None or tokenizer.model_max_length > 1024:
tokenizer.model_max_length = 1024
#----Override Chatbot.postprocess----
#对每条输入的聊天记录进行检查,确保他们不是’None‘
#对每条非’None‘对消息和响应使用’mdtex2html.convert‘进行转换,将他们从Markdown和Tex混合格式转换为HTML格式
def postprocess(self,y):
#确保输入即使时“None”,函数也能正常工作
if y is None:
return []
for i,(message,response) in enumerate(y):
y[i]=(
None if message is None else mdtex2html.convert((message)),
None if message is None else mdtex2html.convert((response)),
)
return y
#当我们在‘gr.chat‘实例上调用’postprocess‘方法时,会执行上面的postprocess函数
gr.Chatbot.postprocess = postprocess
#该函数主要功能是处理和格式化文本,特别是将某种特定的文本格式转化为HTML格式
def parse_text(text):
#将输入的文本‘text’按换行符分割成单独的行
lines = text.split("\n")
#去除掉所有的空格
lines = [line for line in lines if line != ""]
#初始化一个计数器
count = 0
#使用for循环遍历每一行
for i, line in enumerate(lines):
#如果行中包含````(通常在markdown中用于表示代码块的开始或结束)
if "```" in line:
count += 1
#拆分该行以获取其中的内容
items = line.split('`')
#如果count是奇数,则将该行进行替换
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
#如果count是偶数,则将该行进行替换
else:
lines[i] = f'<br></code></pre>'
#对于其他行,如果他们位于代码块内(由if count % 2 ==1 判断),则对行中特殊字符进行转义
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
#在该行前添加标签
lines[i] = "<br>"+line
#将处理后的行连接起来
text = "".join(lines)
#返回处理后的文本
return text
#使用模型生成一个答案流(连续文本块)
def generate_stream(
#用户提供的输入文本
query,
#一个包含先前的对话历史记录的列表
history,
#输入给模型的最大长度
max_input_length,
#模型输出的文本的最大长度
max_output_length):
#定义两个标记tok_ins,tok_res,用于结构化输入文本,以便模型知道哪部分是指令(或问题)和哪部分是响应。
tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Respone:\n"
'''一个实例
### Instruction
How‘s the weather?
### Respone:
It's sunny'''
#格式化字符串,用它来构建模型的完整输入
prompt_input = tok_ins +"{instruction}"+tok_res
#初始化空字符串,用它来构建模型的完整输入,包括之前的历史记录和当前query
sess_text=""
#处理历史会话,使用对话的历史记录和当前的用户输入来构造模型的输入,然后使用tokenizer为模型准备这个输入
#如果历史会话非空,这部分代码会遍历每个历史条目,并将其添加到‘sess_text’中。这样模型机可以看到现的对话历史并在此基础上做出响应
if history:
for s in history:
sess_text += tok_ins + s["human"] + tok_res + s["assistant"]
#当前的‘query’被添加到history中,作为最新的对话条目,此时assistant为空,模型尚未生成
history.append({"human": query, "assistant": ""})
#当前的‘quert’被添加到sess_text中。这样模型的输入包含了整个对话历史,包括最新的用户输入。
sess_text += tok_ins + query.strip()
#使用之前定义的‘prompt_input’格式化字符串,它将sess_text插入到合适的位置,生成模型的完整输入
input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]})
#将文本输入转化为模型可理解的形式
#return_tensors="pt"指示tokenizer返回pytorch张量
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_input_length)
inputs = {k: v.to(device) for k, v in inputs.items()}
#为实时或流式的文本生成提供了一个框架,允许在模型生成文本的同时处理和返回结果
streamer = TextIteratorStreamer(tokenizer,
skip_prompt=True,
skip_special_tokens=True,
spaces_between_special_tokens=False)
#从先前加载的generation_config中取得生成参数,并与输入和其他参数合并
generation_kwargs = generation_config.to_dict()
generation_kwargs.update(dict(inputs))
generation_kwargs['streamer'] = streamer
generation_kwargs['max_new_tokens'] = max_output_length
#使用python的‘Thread’类异步起动文本生成。这意味着model.generate将在后台运行,而主线程可以继续执行
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
#每次模型生成新文本,这个循环都会执行
#同时还更新历史记录,将在新的文本添加到最后一条历史记录中
answer = ""
for new_text in streamer:
if len(new_text) == 0:
continue
if new_text.endswith(tokenizer.eos_token):
new_text = new_text.rsplit(tokenizer.eos_token, 1)[0]
answer += new_text
history[-1]['assistant'] = answer
yield answer, history
#该函数的目的是为了生成模型实时的答复并更新模型的对话历史记录
def predict(input, chatbot, max_input_length, max_generate_length, history):
chatbot.append((parse_text(input), ""))
#生成流式响应
for response, history in generate_stream(
input,
history,
max_input_length=max_input_length,
max_output_length=max_generate_length,
):
if response is None:
break
#更新对话
chatbot[-1] = (parse_text(input), parse_text(response))
#返回结果
yield chatbot, history
定义两个清除函数
#重置用户的输入
def reset_user_input():
return gr.update(value='')
#重置聊天历史
def reset_state():
return [], []
创建一个聊天机器人的图形界面
#进入一个上下文管理器,意味着在这个块中的所有内容都被包括在‘demo’这个对象中
with gr.Blocks() as demo:
#标题
gr.HTML("""<h1 align="center">标题</h1>""")
#创建一个聊天机器人的组件
chatbot = gr.Chatbot()
with gr.Row():
#定义一个宽度为4的列
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
#调整输入和生成的最大长度
max_input_length = gr.Slider(0, 1024, value=512, step=1.0, label="Maximum input length", interactive=True)
max_generate_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum generate length", interactive=True)
history = gr.State([])
#当用户点击submitBtn时,会调用‘predict’函数并传入相对应的参数。它还会调用reset_user_input清除输入
submitBtn.click(predict, [user_input, chatbot, max_input_length, max_generate_length, history], [chatbot, history],
show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
#当用户点击emptyBtn时,会调用reset_state函数并重置聊天机器人和历史状态
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(share=True, inbrowser=True)
在最后一行,share=True时则能开启一个公共链接。
源码在tigerbot的项目中,在后期时会更新API的部署代码的注释。