Tips for Using the Qwen1.5 Large Language Model


Hi everyone, I'm Yufei. Lately I've been digging into RAG and LLM applications, and I've found the Qwen family of models quite pleasant to work with, though a few rough edges remain. Today I'll share two small tips.

Qwen has now released version 1.5. The official site is: Qwen

GitHub repo: github.com/QwenLM

Advantages

  • Qwen comes in many sizes: 0.5B, 1.8B, 4B, 7B, 14B, and 72B. Few vendors today open-source models at 72B or larger.
  • It supports tool calling, RAG (retrieval-augmented generation), role play, AI agents, and more. Many mainstream models still lack these capabilities; among the APIs I have seen, only Zhipu's supports tool calling.

Note: version 1.5 is incompatible with the earlier Qwen models. Neither the code nor the weights are interchangeable, so you cannot mix the two.

Problems

These days we often need streaming output, which requires a recent transformers release; versions below 4.37.0 are not supported. In addition, building an agent or similar features requires stop_words_ids support, which you have to implement yourself.
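Before running the full script below, it is worth verifying the installed transformers version against the 4.37.0 threshold mentioned above. A minimal stdlib-only sketch (the helper name `meets_min_version` is my own):

```python
def meets_min_version(ver: str, minimum=(4, 37, 0)) -> bool:
    """Return True if a dotted version string is at least `minimum`.

    Pre-release suffixes such as "4.37.0.dev0" are tolerated: only the
    numeric parts of each component are compared.
    """
    parts = []
    for piece in ver.split("."):
        digits = "".join(ch for ch in piece if ch.isdigit())
        if not digits:
            break
        parts.append(int(digits))
    parts += [0] * (3 - len(parts))  # pad short versions like "4.37"
    return tuple(parts[:3]) >= minimum


# Example: guard the rest of the script on the required version.
# import transformers
# assert meets_min_version(transformers.__version__), "need transformers >= 4.37.0"
```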

A workaround appears in a GitHub issue: configure a StopWordsLogitsProcessor to add support for stop_words.

To that end, I combined the official code with snippets from GitHub to implement streaming output plus stop_words_ids. The code is below:

Code

from typing import List, Iterable
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.generation import GenerationConfig, LogitsProcessor
from transformers.generation.logits_process import LogitsProcessorList
from threading import Thread


class StopWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
    Args:
        stop_words_ids (:obj:`List[List[int]]`):
            List of list of token ids of stop ids. In order to get the tokens of the words
            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):

        if not isinstance(stop_words_ids, list) or len(stop_words_ids) == 0:
            raise ValueError(
                f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}."
            )
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
            raise ValueError(
                f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
            )
        if any(
                any(
                    (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                    for token_id in stop_word_ids
                )
                for stop_word_ids in stop_words_ids
        ):
            raise ValueError(
                f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
            )

        self.stop_words_ids = list(
            filter(
                lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
            )
        )
        self.eos_token_id = eos_token_id
        for stop_token_seq in self.stop_words_ids:
            assert (
                    len(stop_token_seq) > 0
            ), "Stop words token sequences {} cannot have an empty list".format(
                stop_words_ids
            )

    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        stopped_samples = self._calc_stopped_samples(input_ids)
        for i, should_stop in enumerate(stopped_samples):
            if should_stop:
                # boost the EOS logit so EOS is emitted next, ending generation
                scores[i, self.eos_token_id] = float(2 ** 15)
        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        if len(tokens) == 0:
            # an empty stop sequence trivially matches anything
            return True
        elif len(tokens) > len(prev_tokens):
            # stop sequence is longer than the generated prefix: cannot match
            return False
        elif prev_tokens[-len(tokens):].tolist() == tokens:
            # the generated sequence ends with this stop sequence
            return True
        else:
            return False

    def _calc_stopped_samples(self, prev_input_ids: torch.LongTensor) -> List[bool]:
        stopped_samples = []
        for prev_input_ids_slice in prev_input_ids:
            match = False
            for stop_token_seq in self.stop_words_ids:
                if self._tokens_match(prev_input_ids_slice, stop_token_seq):
                    # this stop sequence matched: mark the sample as stopped
                    match = True
                    break
            stopped_samples.append(match)

        return stopped_samples


checkpoint_path = '/app/Qwen1.5-0.5B-Chat/'
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint_path,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    device_map="cuda:0",
    torch_dtype=torch.float16,
    pad_token_id=151645,  # id of <|im_end|> in the Qwen1.5 vocabulary
    trust_remote_code=True,
).eval()

model.generation_config = GenerationConfig.from_pretrained(
    checkpoint_path,
    pad_token_id=151645,
    trust_remote_code=True,
)
stop_words = ["助手"]  # stop on the word "assistant"
if stop_words is not None:
    print(f"stop_words : {stop_words}")
    stop_words_ids = [tokenizer.encode(w) for w in stop_words]
    print(f"stop_words_ids : {stop_words_ids}")

    stop_words_logits_processor = StopWordsLogitsProcessor(
        stop_words_ids=stop_words_ids,
        eos_token_id=model.generation_config.eos_token_id,
    )
    logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
    logits_processor = None

prompt = "你是谁,请用20个字说明"
messages = [
    {"role": "system", "content": "你是商家小助手"},
    {"role": "user", "content": prompt}
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=1024)
generation_kwargs["logits_processor"] = logits_processor

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    print(new_text, end="", flush=True)  # stream each chunk as it arrives
thread.join()

print()
print(generated_text)
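As an aside, recent transformers releases also expose a StoppingCriteria hook, which halts generation directly instead of boosting the EOS logit. A minimal sketch of that alternative (the class name StopOnTokens is my own; note the batching caveat in the docstring):

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnTokens(StoppingCriteria):
    """Halt generation once the sequence ends with any stop token sequence.

    Unlike the logits-processor approach above, a StoppingCriteria stops the
    whole batch as soon as one sample matches, so it fits batch size 1 best.
    """

    def __init__(self, stop_words_ids):
        self.stop_words_ids = [torch.tensor(ids) for ids in stop_words_ids]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in self.stop_words_ids:
            n = len(stop_ids)
            # compare the tail of the generated sequence with each stop sequence
            if input_ids.shape[1] >= n and torch.equal(input_ids[0, -n:].cpu(), stop_ids):
                return True
        return False


# Usage, mirroring the script above:
# stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_words_ids)])
# model.generate(**model_inputs, stopping_criteria=stopping_criteria, max_new_tokens=1024)
```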

That's all for this post. If it helped, likes, bookmarks, and comments are always welcome.