大家好,我是雨飞。最近,一直在研究 RAG 和大模型应用相关的内容,发现千问系列的大模型使用起来是比较舒服的,但还是存在一些小问题。今天就给大家分享两个使用小技巧。
千问已经出了 1.5 版本,官网是:Qwen
github 地址:github.com/QwenLM
优势
- 千问支持的模型大小很多,包括0.5B、1.8B、4B、7B、14B和72B。目前来说,能开源 72B 及以上大小的厂商还是比较少的。
- 支持工具调用、RAG(检索增强文本生成)、角色扮演、AI Agent等。这个能力,目前很多主流的大模型都是不支持的,只有看到智谱的 API 可以支持工具调用。
注意:1.5 的版本和千问的大模型是不兼容的,代码和权重都不兼容,因此没法混用。
问题
现在很多时候,我们都需要支持流式输出,而这需要较高版本的 transformers,4.37.0 以下的低版本是不支持的。另外,要实现 Agent 或其他功能,还需要 stop_words_ids 的支持,这部分需要自己开发。
在 github 的 issue 中看到了一个解法,就是通过配置StopWordsLogitsProcessor来实现对stop_words的支持。
为此,我汇总了官方和 github 的代码,实现了流式输出+stop_words_ids的功能,代码如下:
代码
from typing import Tuple, List, Union, Iterable
import copy
import torch
import time
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation import LogitsProcessor
from transformers import TextIteratorStreamer
from threading import Thread
class StopWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that forces generation to stop once any of the
    specified stop-word token sequences appears at the end of a sample, by boosting the
    EOS token's score so it is selected next.

    Args:
        stop_words_ids (:obj:`List[List[int]]`):
            List of token-id lists, one inner list per stop word. To obtain the token
            ids of a stop word, use :obj:`tokenizer(stop_word,
            add_prefix_space=True).input_ids`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.

    Raises:
        ValueError: if ``stop_words_ids`` is not a non-empty list of non-empty lists
            of non-negative integers.
    """

    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
        # Validate the outer container, then the inner lists, then the token ids.
        # (Use the builtin `list` for isinstance; `typing.List` is deprecated here.)
        if not isinstance(stop_words_ids, list) or len(stop_words_ids) == 0:
            raise ValueError(
                f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}."
            )
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
            raise ValueError(
                f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
            )
        if any(
            any(
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                for token_id in stop_word_ids
            )
            for stop_word_ids in stop_words_ids
        ):
            raise ValueError(
                f"Each list in `stop_words_ids` has to be a list of non-negative integers, but is {stop_words_ids}."
            )
        # A sequence equal to [eos_token_id] is redundant: generation already stops on EOS.
        self.stop_words_ids = list(
            filter(
                lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
            )
        )
        self.eos_token_id = eos_token_id
        for stop_token_seq in self.stop_words_ids:
            # Raise (not assert) so the check survives `python -O`.
            if len(stop_token_seq) == 0:
                raise ValueError(
                    "Stop words token sequences {} cannot have an empty list".format(
                        stop_words_ids
                    )
                )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """For every sample whose tail matches a stop sequence, boost the EOS score
        to a large constant (2**15) so EOS is chosen at the next step."""
        stopped_samples = self._calc_stopped_samples(input_ids)
        for i, should_stop in enumerate(stopped_samples):
            if should_stop:
                scores[i, self.eos_token_id] = float(2 ** 15)
        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        """Return True iff `prev_tokens` ends with the exact sequence `tokens`."""
        if len(tokens) == 0:
            # An empty stop sequence trivially matches (guarded against in __init__).
            return True
        elif len(tokens) > len(prev_tokens):
            # The stop sequence is longer than what has been generated so far.
            return False
        elif prev_tokens[-len(tokens):].tolist() == tokens:
            # The generated tail equals the stop sequence.
            return True
        else:
            return False

    def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
        """Return one bool per sample: whether any stop sequence matches its tail."""
        stopped_samples = []
        for prev_input_ids_slice in prev_input_ids:
            match = False
            for stop_token_seq in self.stop_words_ids:
                if self._tokens_match(prev_input_ids_slice, stop_token_seq):
                    match = True
                    break
            stopped_samples.append(match)
        return stopped_samples
# Local path to the Qwen1.5 chat checkpoint (weights + tokenizer files).
checkpoint_path = '/app/Qwen1.5-0.5B-Chat/'
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint_path,
    trust_remote_code=True,
)
# Load in fp16 on a single GPU and switch to eval mode (disables dropout).
# NOTE(review): 151645 is presumably Qwen's <|im_end|> token, reused here as the
# pad token because no dedicated pad token is set — confirm against the tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    device_map="cuda:0",
    torch_dtype=torch.float16,
    pad_token_id=151645,
    trust_remote_code=True,
).eval()
# Load the checkpoint's generation defaults and pin the same pad token id there,
# so generate() pads consistently with the model config above.
model.generation_config = GenerationConfig.from_pretrained(
    checkpoint_path,
    pad_token_id=151645,
    trust_remote_code=True,
)
# Strings that should halt generation: once a stop word's token sequence appears
# at the end of the output, the processor forces EOS.
stop_words = ["助手"]
# Truthiness check instead of `is not None`: an empty list would otherwise be
# passed to StopWordsLogitsProcessor, which rejects it with a ValueError.
if stop_words:
    print(f"stop_words : {stop_words}")
    stop_words_ids = [tokenizer.encode(word) for word in stop_words]
    print(f"stop_words_ids : {stop_words_ids}")
    stop_words_logits_processor = StopWordsLogitsProcessor(
        stop_words_ids=stop_words_ids,
        eos_token_id=model.generation_config.eos_token_id,
    )
    logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
    # No stop words configured — let generate() run with its defaults.
    logits_processor = None
# Build a chat-formatted prompt via the tokenizer's chat template.
prompt = "你是谁,请用20个字说明"
messages = [
    {"role": "system", "content": "你是商家小助手"},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# generate() runs in a background thread; the streamer yields decoded text
# pieces here as tokens are produced (prompt and special tokens skipped).
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=1024)
generation_kwargs["logits_processor"] = logits_processor
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    print("text: {}".format(generated_text))
# The streamer is exhausted only when generation ends, but join the worker
# explicitly so the final result is read after the thread has fully finished.
thread.join()
print(generated_text)
好了,我写完了,有帮助欢迎点赞收藏评论一键三连呀。