第 20 课: 自定义 Provider -- 从零实现一个 ChatModel

0 阅读6分钟

课程目标

动手实践 Provider 开发:理解 _generate() 的最小实现要求,掌握 FakeChatModel 系列测试工具,了解 @langchain/standard-tests 标准测试套件。


20.1 Provider 开发脚手架

LangChain.js 提供了官方脚手架工具:

npx create-langchain-integration my-provider

生成的项目骨架包含:

  • src/chat_models.ts -- ChatModel 实现模板
  • src/tests/ -- 单元测试和集成测试模板
  • tsdown.config.ts -- ESM + CJS 双输出构建配置
  • 标准测试的引用模板

20.2 最小实现:_generate()

实现一个自定义 ChatModel 只需三步:

20.2.1 继承 BaseChatModel

import { BaseChatModel, BaseChatModelCallOptions } from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { BaseMessage, AIMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";

interface MyChatModelInput {
  apiKey?: string;
  model?: string;
  temperature?: number;
}

class ChatMyProvider extends BaseChatModel<BaseChatModelCallOptions> {
  model: string;
  temperature: number;

  constructor(fields: MyChatModelInput = {}) {
    super(fields);
    this.model = fields.model ?? "my-default-model";
    this.temperature = fields.temperature ?? 0.7;
  }

  _llmType(): string {
    return "my-provider";  // 用于日志和序列化标识
  }

  _combineLLMOutput(): undefined {
    return undefined;  // batch 时合并多次调用的元数据,简单情况返回 undefined
  }

  async _generate(
    messages: BaseMessage[],
    options?: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    // 1. 将 LangChain 消息转换为你的 API 格式
    const apiMessages = messages.map(m => ({
      role: m._getType(),
      content: typeof m.content === "string" ? m.content : JSON.stringify(m.content),
    }));

    // 2. 调用你的 API
    const response = await this._callMyAPI(apiMessages, options?.signal);

    // 3. 将响应转换回 LangChain 格式
    const text = response.text;
    await runManager?.handleLLMNewToken(text);  // 通知 callback 系统

    return {
      generations: [{
        message: new AIMessage(text),
        text,
      }],
    };
  }

  private async _callMyAPI(messages: any[], signal?: AbortSignal) {
    // 你的 API 调用逻辑
    return { text: "模拟响应" };
  }
}

关键点

  • _llmType() 返回 Provider 标识字符串
  • _generate() 是唯一必须实现的方法
  • 记得调用 runManager?.handleLLMNewToken() 以支持 callback 系统

20.2.2 可选增强:流式支持

async *_streamResponseChunks(
  messages: BaseMessage[],
  options: this["ParsedCallOptions"],
  runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
  const apiMessages = this._convertMessages(messages);
  const stream = await this._callMyStreamAPI(apiMessages);

  for await (const chunk of stream) {
    const text = chunk.text;
    await runManager?.handleLLMNewToken(text);
    yield new ChatGenerationChunk({
      message: new AIMessageChunk({ content: text }),
      text,
    });
  }
}

20.2.3 可选增强:bindTools

override bindTools(
  tools: StructuredTool[],
  kwargs?: Partial<BaseChatModelCallOptions>
): Runnable {
  const formattedTools = tools.map(tool => ({
    name: tool.name,
    description: tool.description,
    parameters: toJsonSchema(tool.schema),
  }));
  return this.withConfig({ tools: formattedTools, ...kwargs });
}

20.3 FakeChatModel 系列:测试利器

LangChain.js 提供了多种 Fake 模型用于测试,避免依赖真实 API。

20.3.1 FakeChatModel

源码位置: libs/langchain-core/src/utils/testing/chat_models.ts:56

export class FakeChatModel extends BaseChatModel {
  async _generate(
    messages: BaseMessage[],
    options?: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    // 将所有输入消息的内容拼接后返回
    const text = messages.map(m => {
      if (typeof m.content === "string") return m.content;
      return JSON.stringify(m.content, null, 2);
    }).join("\n");

    await runManager?.handleLLMNewToken(text);
    return { generations: [{ message: new AIMessage(text), text }] };
  }
}

用途:最简单的 mock,回显输入内容。适合测试链的组合逻辑而不关心模型输出。

20.3.2 FakeStreamingChatModel

源码位置: libs/langchain-core/src/utils/testing/chat_models.ts:101

export class FakeStreamingChatModel extends BaseChatModel {
  sleep = 50;            // 每个 chunk 之间的延迟
  responses: BaseMessage[] = [];  // 预设的完整响应
  chunks: AIMessageChunk[] = [];  // 预设的精确 chunk 序列
  toolStyle: "openai" | "anthropic" | "bedrock" | "google" = "openai";
  thrownErrorString?: string;     // 模拟错误

  bindTools(tools) {
    // 根据 toolStyle 格式化工具定义
    const toolDicts = tools.map(t => {
      switch (this.toolStyle) {
        case "openai":
          return { type: "function", function: { name: t.name, ... } };
        case "anthropic":
          return { name: t.name, input_schema: ... };
        // ...
      }
    });
    // ...
  }
}

用途:模拟流式输出,支持预设 chunk 序列和多种工具格式。

20.3.3 FakeBuiltModel (fakeModel)

源码位置: libs/langchain-core/src/testing/fake_model_builder.ts

import { fakeModel } from "@langchain/core/testing";

const model = fakeModel()
  .respond(new AIMessage("你好!"))
  .respond(new AIMessage("再见!"))
  .respondWithTools([{ name: "get_weather", args: { city: "北京" } }]);

// 第一次调用返回 "你好!"
const r1 = await model.invoke([new HumanMessage("hi")]);

// 第二次调用返回 "再见!"
const r2 = await model.invoke([new HumanMessage("bye")]);

// 第三次调用返回带工具调用的消息
const r3 = await model.invoke([new HumanMessage("天气")]);

// 访问调用记录
console.log(model.calls);     // 所有调用的 messages 和 options
console.log(model.callCount); // 3

用途:最灵活的测试工具,支持:

  • FIFO 响应队列(先进先出)
  • 动态响应工厂函数
  • 工具调用模拟
  • 错误模拟
  • 调用记录审计

20.4 自定义 Vitest 断言

源码位置: libs/langchain-core/src/testing/matchers.ts

LangChain.js 提供了 Vitest 扩展断言:

import { expect } from "vitest";
import { langchainCoreMatchers } from "@langchain/core/testing";

expect.extend(langchainCoreMatchers);

// 类型安全的消息断言
expect(result).toBeHumanMessage("hello");
expect(result).toBeAIMessage("world");
expect(result).toBeSystemMessage();
expect(result).toBeToolMessage();

20.5 标准测试套件

源码位置: internal/standard-tests/src/

标准测试确保所有 Provider 行为一致。

20.5.1 单元测试

// internal/standard-tests/src/unit_tests/chat_models.ts
export abstract class ChatModelUnitTests<CallOptions, OutputMessageType, ConstructorArgs>
  extends BaseChatModelsTests
{
  // 测试标准 LangSmith 参数
  expectedLsParams(): Partial<LangSmithParams> {
    return {
      ls_provider: "string",
      ls_model_name: "string",
      ls_model_type: "chat",
      ls_temperature: 0,
      ls_max_tokens: 0,
      ls_stop: ["Array<string>"],
    };
  }
  // ... 更多标准测试
}

20.5.2 使用标准测试

import { ChatModelUnitTests } from "@langchain/standard-tests/unit_tests/chat_models";
import { ChatMyProvider } from "../src/chat_models.js";

class MyChatModelUnitTests extends ChatModelUnitTests {
  constructor() {
    super({
      Cls: ChatMyProvider,
      constructorArgs: {
        model: "my-model",
        temperature: 0,
      },
    });
  }
}

const tests = new MyChatModelUnitTests();
tests.runTests();

20.5.3 测试分类

测试类型文件命名运行条件用途
单元测试*.test.ts无需 API key验证基本行为
集成测试*.int.test.ts需要真实 API key验证真实 API 交互
标准测试继承 standard-tests视类型而定确保 Provider 行为一致

20.6 完整实战:用 FakeListChatModel 模拟 Provider

import { describe, it, expect } from "vitest";
import { FakeStreamingChatModel } from "@langchain/core/utils/testing";
import { fakeModel } from "@langchain/core/testing";
import { HumanMessage, AIMessage, AIMessageChunk } from "@langchain/core/messages";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { StringOutputParser } from "@langchain/core/output_parsers";

describe("自定义链测试", () => {
  it("应该正确执行 prompt -> model -> parser 链", async () => {
    // 使用 fakeModel 模拟响应
    const model = fakeModel()
      .respond(new AIMessage("北京今天晴天,温度 25 度"));

    const prompt = ChatPromptTemplate.fromMessages([
      ["system", "你是一个天气助手"],
      ["human", "{question}"],
    ]);
    const parser = new StringOutputParser();

    const chain = prompt.pipe(model).pipe(parser);
    const result = await chain.invoke({ question: "北京天气怎么样?" });

    expect(result).toBe("北京今天晴天,温度 25 度");
    expect(model.callCount).toBe(1);
  });

  it("应该支持流式输出", async () => {
    const model = new FakeStreamingChatModel({
      responses: [new AIMessage("Hello World")],
      sleep: 10,
    });

    const chunks: string[] = [];
    for await (const chunk of await model.stream([new HumanMessage("hi")])) {
      if (typeof chunk.content === "string") {
        chunks.push(chunk.content);
      }
    }

    expect(chunks.join("")).toBe("Hello World");
  });

  it("应该支持工具调用模拟", async () => {
    const model = fakeModel()
      .respondWithTools([
        { name: "get_weather", args: { city: "北京" } },
      ]);

    const result = await model.invoke([new HumanMessage("天气")]);
    expect(result.tool_calls).toHaveLength(1);
    expect(result.tool_calls?.[0].name).toBe("get_weather");
  });

  it("应该支持错误模拟", async () => {
    const model = fakeModel()
      .respond(new Error("API rate limit exceeded"));

    await expect(
      model.invoke([new HumanMessage("hi")])
    ).rejects.toThrow("API rate limit exceeded");
  });
});

20.7 发布检查清单

如果你要将自定义 Provider 发布为 npm 包:

  • _generate() 正确实现,返回 ChatResult
  • _llmType() 返回唯一标识
  • _streamResponseChunks() 可选但推荐实现
  • bindTools() 在模型支持工具时实现
  • ESM + CJS 双输出配置(tsdown.config.ts
  • 单元测试通过(使用 FakeModel 系列)
  • 集成测试通过(使用真实 API)
  • 标准测试通过(继承 ChatModelUnitTests
  • lc_secrets 声明敏感字段
  • README 包含使用示例

20.8 源码精读路线

优先级文件关注点
P0langchain-core/src/utils/testing/chat_models.tsFakeChatModelFakeStreamingChatModel 实现
P0langchain-core/src/testing/fake_model_builder.tsFakeBuiltModelfakeModel() 构建器
P1langchain-core/src/testing/matchers.tsVitest 自定义断言
P1internal/standard-tests/src/unit_tests/chat_models.ts标准单元测试基类
P2internal/standard-tests/src/integration_tests/标准集成测试
P2任意简单 Provider 包参考完整实现

本课收获总结

级别你应该掌握的
🟢 基础理解 Provider 接入的最小实现要求:继承 BaseChatModel,实现 _generate()
🔵 中阶能用 fakeModel() 构建器和 FakeStreamingChatModel 编写无 API 依赖的测试
🟡 高阶实现完整的 _generate() + _streamResponseChunks() + bindTools()
🟠 资深用标准测试套件验证 Provider 实现的正确性和一致性
🔴 架构设计 Provider 的测试金字塔:FakeModel 单元测试 -> 真实 API 集成测试 -> 标准测试合规验证

下一课预告

第 21 课进入 Callbacks 系统——框架的"神经网络"。理解事件如何在 Runnable 链中传播,回调处理器如何观测执行过程。