掌握LLM数据管理:使用Argilla记录和优化模型输入输出

74 阅读2分钟

引言

在大型语言模型(LLM)时代,数据的有效管理和优化变得至关重要。Argilla作为一个开源的数据管理平台,为开发者提供了便捷的LLM反馈和监控工具。在本文中,我们将探讨如何使用ArgillaCallbackHandler跟踪LLM的输入输出,从而生成用于未来微调的数据集。

主要内容

安装和设置

首先,我们需要安装所需的软件包:

%pip install --upgrade --quiet langchain langchain-openai argilla

获取Argilla和OpenAI的API凭证:

  1. 访问Argilla UI,点击个人资料并进入“我的设置”,复制API密钥。
  2. 获取OpenAI API凭证,请访问 OpenAI API Keys

设置环境变量:

import os

os.environ["ARGILLA_API_URL"] = "..."  # 使用API代理服务提高访问稳定性
os.environ["ARGILLA_API_KEY"] = "..."
os.environ["OPENAI_API_KEY"] = "..."

Argilla设置

为了使用ArgillaCallbackHandler,我们需要创建一个新的FeedbackDataset来跟踪LLM实验:

import argilla as rg
from packaging.version import parse as parse_version

if parse_version(rg.__version__) < parse_version("1.8.0"):
    raise RuntimeError("`FeedbackDataset` is only available in Argilla v1.8.0 or higher.")

dataset = rg.FeedbackDataset(
    fields=[
        rg.TextField(name="prompt"),
        rg.TextField(name="response"),
    ],
    questions=[
        rg.RatingQuestion(name="response-rating", description="How would you rate the quality of the response?", values=[1, 2, 3, 4, 5], required=True),
        rg.TextQuestion(name="response-feedback", description="What feedback do you have for the response?", required=False),
    ],
    guidelines="You're asked to rate the quality of the response and provide feedback.",
)

rg.init(api_url=os.environ["ARGILLA_API_URL"], api_key=os.environ["ARGILLA_API_KEY"])
dataset.push_to_argilla("langchain-dataset")

跟踪

我们可以使用ArgillaCallbackHandler来跟踪LLM的输入输出:

from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler

argilla_callback = ArgillaCallbackHandler(
    dataset_name="langchain-dataset",
    api_url=os.environ["ARGILLA_API_URL"],  # 使用API代理服务提高访问稳定性
    api_key=os.environ["ARGILLA_API_KEY"],
)

代码示例

示例1:跟踪LLM单次运行

使用OpenAI LLM生成文本,并记录其输入输出。

from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_openai import OpenAI

callbacks = [StdOutCallbackHandler(), argilla_callback]
llm = OpenAI(temperature=0.9, callbacks=callbacks)
llm.generate(["Tell me a joke", "Tell me a poem"] * 3)

示例2:在链中跟踪LLM

创建一个链,记录初始提示和最终响应:

from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate

template = """You are a playwright. Given the title of play, it is your job to write a synopsis for that title. Title: {title} Playwright: This is a synopsis for the above play:"""
prompt_template = PromptTemplate(input_variables=["title"], template=template)
synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)

test_prompts = [{"title": "Documentary about Bigfoot in Paris"}]
synopsis_chain.apply(test_prompts)

常见问题和解决方案

  • 版本兼容性:确保安装的Argilla版本支持FeedbackDataset功能。
  • 网络限制:某些地区可能需要使用API代理服务以提高访问Argilla和OpenAI的稳定性。

总结和进一步学习资源

通过本指南,您已经了解如何使用Argilla有效地跟踪和优化LLM的输入输出。有关更多信息和高级用法,请参阅以下资源:

参考资料

  • Argilla 官方文档
  • Langchain 文档
  • OpenAI API 文档

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---