**有效格式输出的终极指南:LM Format Enforcer**

207 阅读3分钟
# 有效格式输出的终极指南:LM Format Enforcer

在人工智能驱动的应用中,生成的文本往往需要符合特定的格式。对于开发者来说,确保输出与所需格式一致是一个不小的挑战。本文将介绍一个名为LM Format Enforcer的库,该库能够强制执行语言模型的输出格式,通过过滤无效的token来保持格式的正确性。

## 引言

LM Format Enforcer是一个实验性的库,通过结合字符级解析器和tokenizer前缀树来限制语言模型的输出格式,确保生成文本符合预定义的格式。这特别适用于需要生成结构化数据的场景,比如JSON格式的API调用等。

## 主要内容

### 设置模型

首先,我们需要设置一个Llama2模型,并初始化我们所需的输出格式。请注意,使用Llama2模型需要获得授权访问。

```python
import logging
from langchain_experimental.pydantic_v1 import BaseModel
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.ERROR)

class PlayerInformation(BaseModel):
    first_name: str
    last_name: str
    num_seasons_in_nba: int
    year_of_birth: int

model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

使用HuggingFace Pipeline建立基线

在应用LM Format Enforcer之前,我们可以通过HuggingFace Pipeline来生成一个基线输出。

from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

hf_model = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200
)

original_model = HuggingFacePipeline(pipeline=hf_model)

def make_instruction_prompt(message):
    return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> {message} [/INST]"

def get_prompt(player_name):
    return make_instruction_prompt(
        prompt.format(
            player_name=player_name, arg_schema=PlayerInformation.schema_json()
        )
    )

generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)

应用JSONFormer LLM Wrapper

通过LMFormatEnforcer,我们可以增强文本输出,使其符合特定JSON格式的要求。

from langchain_experimental.llms import LMFormatEnforcer

lm_format_enforcer = LMFormatEnforcer(
    json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)

适用性广泛,无需担心解析错误。

批量处理

LMFormatEnforcer支持批量处理,能够同时生成多个符合格式的输出。

prompts = [
    get_prompt(name) for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]
]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
    print(generation[0].text)

正则表达式模式

LMFormatEnforcer还支持通过正则表达式来过滤输出,这对于特定格式的文本生成非常有帮助。

question_prompt = "When was Michael Jordan Born? Please answer in mm/dd/yyyy format."
date_regex = r"(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}"
answer_regex = " In mm/dd/yyyy format, Michael Jordan was born in " + date_regex

lm_format_enforcer = LMFormatEnforcer(regex=answer_regex, pipeline=hf_model)

full_prompt = make_instruction_prompt(question_prompt)
print("Enforced Output:")
print(lm_format_enforcer.predict(full_prompt))

常见问题和解决方案

  • 如何确保输出格式始终正确? 使用LMFormatEnforcer定义明确的JSON Schema或正则表达式来限制输出。
  • 如果模型生成与格式不符的内容怎么办? 确保正确地设置了LMFormatEnforcer的参数以过滤无效的token。

总结和进一步学习资源

LM Format Enforcer提供了一种高效的方法来确保语言模型的输出格式正确。在需要生成结构化数据的场景中,它是一个非常有用的工具。开发者还可以参考HuggingFace文档Langchain文档进行进一步学习。

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---