【Dify(v1.x) 核心源码深入解析】model_runtime 模块

9 阅读12分钟

重磅推荐专栏: 《大模型AIGC》 《课程大纲》 《知识星球》

本专栏致力于探索和讨论当今最前沿的技术趋势和应用领域,包括但不限于ChatGPT和Stable Diffusion等。我们将深入研究大型模型的开发和应用,以及与之相关的人工智能生成内容(AIGC)技术。通过深入的技术解析和实践经验分享,旨在帮助读者更好地理解和应用这些领域的最新进展

引言

在人工智能应用开发的浪潮中,模型管理与调用的效率和灵活性成为了决定项目成败的关键因素之一。Dify 的 model_runtime 模块以其清晰的架构和强大的功能,为开发者提供了一个高效、灵活且易于扩展的模型管理与调用解决方案。本文将深入解析 Dify 的 model_runtime 模块,从其核心架构到实际应用场景,结合丰富的代码示例,帮助读者全面理解这一模块的运作原理和使用方法。

一、模块概览

1.1 模块定位与作用

Dify 的 model_runtime 模块是连接模型供应商与业务逻辑的桥梁。它向上层应用提供统一的模型调用接口,向下兼容多种模型供应商的凭据验证和模型调用逻辑。通过解耦模型调用与业务逻辑,model_runtime 模块使得开发者能够轻松地横向扩展支持的模型类型和供应商,同时在前端实现无侵入式的动态展示。

1.2 核心支持功能

model_runtime 模块支持以下五种模型类型的能力调用:

  • LLM(Large Language Model):支持文本补全、对话以及预计算 tokens 能力。
  • Text Embedding Model:支持文本 Embedding 和预计算 tokens 能力。
  • Rerank Model:支持分段 Rerank 能力。
  • Speech-to-text Model:支持语音转文本能力。
  • Text-to-speech Model:支持文本转语音能力。
  • Moderation:支持内容审核能力。

二、模块架构详解

2.1 三层架构设计

Model Runtime 采用分层架构设计,从上到下分为工厂方法层、供应商层和模型层。这种分层设计不仅提高了代码的可维护性和可扩展性,还使得各层职责分明,便于开发者理解和使用。

2.1.1 工厂方法层

工厂方法层是模块的入口,提供以下核心功能:

  • 获取所有支持的供应商列表。
  • 获取所有可用的模型列表。
  • 获取特定供应商的实例。
  • 执行供应商和模型的凭据鉴权。
graph TD
    A[应用层] --> B[工厂方法层]
    B --> C[获取供应商列表]
    B --> D[获取模型列表]
    B --> E[获取供应商实例]
    B --> F[凭据鉴权]

2.1.2 供应商层

供应商层负责管理供应商相关的信息和逻辑,包括:

  • 获取当前供应商支持的模型列表。
  • 获取模型实例。
  • 执行供应商凭据鉴权。
  • 提供供应商配置规则信息。
graph TD
    B[工厂方法层] --> C[供应商层]
    C --> D[获取模型列表]
    C --> E[获取模型实例]
    C --> F[凭据鉴权]
    C --> G[配置规则信息]

对于凭据的处理,供应商层支持两种情况:

  1. 中心化供应商:如 OpenAI,需要定义如 api_key 这类的鉴权凭据。
  2. 本地部署供应商:如 Xinference,需要定义如 server_url 这类的地址凭据,有时还需要定义 model_uid 这类的模型类型凭据。

2.1.3 模型层

模型层是模块的核心,负责具体的模型调用逻辑,包括:

  • 各种模型类型的直接调用。
  • 预定义模型配置信息管理。
  • 获取预定义或远程模型列表。
  • 执行模型凭据鉴权。
graph TD
    C[供应商层] --> D[模型层]
    D --> E[模型调用]
    D --> F[配置信息管理]
    D --> G[模型列表获取]
    D --> H[凭据鉴权]

在模型层中,需要区分模型参数和模型凭据:

  • 模型参数:如 LLM 的 max_tokens、temperature 等,这些参数经常变动,由用户在前端调整,后端定义规则以便前端展示。
  • 模型凭据:如 api_key、server_url 等,这些凭据不常变动,配置后直接传递给模型层使用。

2.2 类图展示

以下是 model_runtime 模块的核心类图,展示了主要类之间的继承关系和关联关系:

classDiagram
    class Application
    class FactoryLayer
    class ProviderLayer
    class ModelLayer
    class BaseModel
    class LLMModel
    class TextEmbeddingModel
    class RerankModel
    class SpeechToTextModel
    class TextToSpeechModel
    class ModerationModel

    Application --> FactoryLayer
    FactoryLayer --> ProviderLayer
    ProviderLayer --> ModelLayer
    ModelLayer --> BaseModel
    BaseModel <|-- LLMModel
    BaseModel <|-- TextEmbeddingModel
    BaseModel <|-- RerankModel
    BaseModel <|-- SpeechToTextModel
    BaseModel <|-- TextToSpeechModel
    BaseModel <|-- ModerationModel

三、核心功能解析

3.1 凭据验证机制

凭据验证是确保模型调用安全性和合法性的关键步骤。model_runtime 模块提供了两 级凭据验证:供应商凭据验证和模型凭据验证。

3.1.1 供应商凭据验证

供应商凭据验证通过 validate_provider_credentials 方法实现。该方法根据供应商的配置规则对凭据进行校验。以下是一个简化的验证流程:

graph TD
    A[开始] --> B[获取凭据配置规则]
    B --> C{凭据是否符合规则?}
    C -->|是| D[验证通过]
    C -->|否| E[抛出异常]

代码示例:

def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Validate provider credentials
    """
    # 获取供应商凭据配置规则
    credential_schema = self.provider_credential_schema
    # 遍历每个凭据表单项进行验证
    for schema in credential_schema.credential_form_schemas:
        if schema.variable not in credentials:
            raise CredentialsValidateFailedError(f"Missing required credential: {schema.variable}")
        # 进行类型验证、必填验证等详细校验逻辑
        # ...

这段代码首先获取供应商的凭据配置规则,然后对每个凭据项进行详细验证,包括是否存在、类型是否正确、是否符合长度限制等。如果凭据不符合要求,则抛出 CredentialsValidateFailedError 异常。

3.1.2 模型凭据验证

模型凭据验证与供应商凭据验证类似,但在模型层进行。以下是模型凭据验证的简化流程:

graph TD
    A[开始] --> B[获取模型凭据配置规则]
    B --> C{凭据是否符合规则?}
    C -->|是| D[验证通过]
    C -->|否| E[抛出异常]

代码示例:

def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials
    """
    # 获取模型凭据配置规则
    credential_schema = self.model_credential_schema
    # 遍历每个凭据表单项进行验证
    for schema in credential_schema.credential_form_schemas:
        if schema.variable not in credentials:
            raise CredentialsValidateFailedError(f"Missing required credential: {schema.variable}")
        # 进行类型验证、必填验证等详细校验逻辑
        # ...

3.2 模型调用流程

模型调用是 model_runtime 模块的核心功能之一。以下以 LLM 模型调用为例,展示其完整的调用流程:

sequenceDiagram
    participant 应用层
    participant 工厂方法层
    participant 供应商层
    participant 模型层
    participant LLM模型

    应用层->>工厂方法层: 请求获取LLM模型实例
    工厂方法层->>供应商层: 根据供应商名称获取供应商实例
    供应商层->>模型层: 根据模型名称获取模型配置
    模型层->>LLM模型: 初始化LLM模型实例
    应用层->>LLM模型: 调用LLM模型的_invoke方法
    LLM模型->>LLM模型: 执行模型调用逻辑
    LLM模型-->>应用层: 返回调用结果(同步或流式)

代码示例:

def _invoke(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict,
    tools: Optional[list[PromptMessageTool]] = None,
    stop: Optional[list[str]] = None,
    stream: bool = True,
    user: Optional[str] = None
) -> Union[LLMResult, Generator]:
    """
    Invoke large language model
    """
    try:
        # 根据是否流式返回选择不同的处理逻辑
        if stream:
            return self._handle_stream_response(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                user=user
            )
        else:
            return self._handle_sync_response(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                user=user
            )
    except Exception as e:
        # 异常处理逻辑,将模型异常映射到统一的InvokeError
        mapped_error = self._map_invoke_error(e)
        raise mapped_error from e

_invoke 方法中,根据是否需要流式返回,分别调用 _handle_stream_response_handle_sync_response 方法。流式返回时,方法会生成一个结果块生成器,逐块返回结果;同步返回时,直接返回完整的调用结果。同时,方法包含了异常处理逻辑,将模型调用异常映射到统一的 InvokeError 类型,便于上层应用处理。

3.3 模型参数管理

model_runtime 模块允许为不同模型定义个性化的参数规则。这些参数规则定义在 YAML 配置文件中,并通过 ParameterRule 类进行解析和验证。

YAML 配置示例:

parameter_rules:
- name: temperature
  use_template: temperature
  label:
    zh_Hans: 温度
    en_US: Temperature
- name: max_tokens
  use_template: max_tokens
  default: 1024
  min: 1
  max: 4096
  label:
    zh_Hans: 最大生成长度
    en_US: Max Tokens

代码解析:

class ParameterRule(BaseModel):
    """
    Model class for parameter rule.
    """
    name: str  # 参数名称
    use_template: Optional[str] = None  # 使用的模板名称
    label: I18nObject  # 参数标签,支持多语言
    type: ParameterType  # 参数类型
    help: Optional[I18nObject] = None  # 参数帮助信息
    required: bool = False  # 是否必填
    default: Optional[Any] = None  # 默认值
    min: Optional[float] = None  # 最小值(仅数字类型适用)
    max: Optional[float] = None  # 最大值(仅数字类型适用)
    precision: Optional[int] = None  # 精度(仅数字类型适用)
    options: list[str] = []  # 下拉选项值(仅字符串类型适用)

ParameterRule 类通过 Pydantic 的数据验证功能,确保每个参数都符合定义的规则。前端可以根据这些规则动态生成参数配置界面,用户在界面上调整参数后,后端可以直接使用这些参数进行模型调用。

四、扩展性支持

4.1 新增供应商

新增供应商主要分为以下几个步骤:

  1. 创建供应商 YAML 文件:根据供应商的信息和配置规则编写 YAML 配置文件。
  2. 创建供应商代码:实现一个继承自 ModelProvider 的供应商类,并实现必要的方法,如凭据验证方法。
  3. 创建模型类型模块:在供应商模块下创建对应模型类型的子模块,如 llm、text_embedding 等。
  4. 实现模型调用代码:在模型类型模块下实现具体的模型调用逻辑。
  5. 编写测试代码:确保新增供应商的功能正常。

供应商 YAML 配置示例:

provider: xinference
label:
  en_US: Xorbits Inference
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
provider_credential_schema:
  credential_form_schemas:
  - variable: model_type
    type: select
    label:
      en_US: Model type
      zh_Hans: 模型类型
    required: true
    options:
    - value: text-generation
      label:
        en_US: Language Model
        zh_Hans: 语言模型
    - value: embeddings
      label:
        en_US: Text Embedding
    - value: reranking
      label:
        en_US: Rerank
  - variable: model_name
    type: text-input
    label:
      en_US: Model name
      zh_Hans: 模型名称
    required: true
    placeholder:
      zh_Hans: 填写模型名称
      en_US: Input model name
  - variable: server_url
    label:
      zh_Hans: 服务器 URL
      en_US: Server url
    type: text-input
    required: true
    placeholder:
      zh_Hans: 在此输入 Xinference 的服务器地址
      en_US: Enter the url of your Xinference
  - variable: model_uid
    label:
      zh_Hans: 模型 UID
      en_US: Model uid
    type: text-input
    required: true
    placeholder:
      zh_Hans: 在此输入您的 Model UID
      en_US: Enter the model uid

供应商类实现示例:

class XinferenceProvider(ModelProvider):
    """
    Provider class for Xinference.
    """
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials for Xinference.
        Since Xinference is a customizable model provider, this method is not actually used.
        """
        pass

4.2 新增模型

新增模型的过程根据模型是否为预定义模型有所不同。以下是预定义模型和自定义模型的新增流程对比:

特性预定义模型自定义模型
是否需要 YAML 配置需要,在模型类型模块下创建同名 YAML 文件不需要,模型配置通过代码动态生成
凭据配置位置在供应商 YAML 中定义统一凭据规则在模型调用代码中动态定义凭据规则
调用实现复杂度简单,直接调用供应商 API 获取模型列表和配置复杂,需要处理模型特有的参数和调用逻辑
扩展灵活性低,依赖供应商提供的 API 和预定义模型列表高,支持任意模型配置和调用逻辑

预定义模型 YAML 配置示例:

model: claude-2.1
label:
  en_US: claude-2.1
model_type: llm
features:
- agent-thought
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
- name: temperature
  use_template: temperature
- name: top_p
  use_template: top_p
- name: max_tokens_to_sample
  use_template: max_tokens
  default: 4096
  min: 1
  max: 4096
pricing:
  input: '8.00'
  output: '24.00'
  unit: '0.000001'
  currency: USD

自定义模型调用代码示例:

class XinferenceAILargeLanguageModel(LargeLanguageModel):
    """
    LLM class for Xinference.
    """
    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke Xinference LLM.
        """
        # 获取模型调用所需凭据
        server_url = credentials.get("server_url")
        model_uid = credentials.get("model_uid")
        # 构建模型调用请求
        request_data = {
            "model_uid": model_uid,
            "prompt": prompt_messages,
            "parameters": model_parameters
        }
        # 发送请求到 Xinference 服务器
        response = requests.post(f"{server_url}/api/v1/models/invoke", json=request_data, stream=stream)
        # 处理响应结果
        if stream:
            return self._handle_stream_response(response, model, credentials, prompt_messages, model_parameters, tools, stop, user)
        else:
            return self._handle_sync_response(response, model, credentials, prompt_messages, model_parameters, tools, stop, user)

五、开发实践指南

5.1 调试技巧

在开发和调试 model_runtime 模块时,可以利用以下技巧提高效率:

  1. 使用断点调试:在 IDE 中设置断点,逐步执行代码,观察数据流向和变量状态。
  2. 日志记录:通过 LoggingCallback 类记录模型调用过程中的关键信息,便于定位问题。

日志记录示例:

class LoggingCallback(Callback):
    """
    Callback class for logging LLM invoke process.
    """
    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None) -> None:
        """
        Log before invoke.
        """
        self.print_text("\n[on_llm_before_invoke]\n", color="blue")
        self.print_text(f"Model: {model}\n", color="blue")
        self.print_text("Parameters:\n", color="blue")
        for key, value in model_parameters.items():
            self.print_text(f"\t{key}: {value}\n", color="blue")
        # 更多日志记录逻辑...
  1. 单元测试:为每个模型和供应商编写单元测试,确保新增功能正常,旧功能不受影响。

5.2 性能优化建议

为了提高 model_runtime 模块的性能,可以采取以下措施:

  1. 异步调用:对于支持异步调用的模型,使用异步 IO 操作减少等待时间。
  2. 缓存机制:对频繁调用的模型结果或配置信息引入缓存,减少重复计算和网络请求。
  3. 模型参数优化:根据实际业务需求调整模型参数,如减少 max_tokens、调整 temperature 等,以降低计算量和响应时间。

5.3 常见问题解答

Q1: 如何处理模型调用超时问题?

A1: 可以在模型调用时设置超时参数,并在超时后重试或返回默认结果。例如,在发送 HTTP 请求时设置 timeout 参数:

response = requests.post(url, json=request_data, timeout=10)  # 超时时间为10秒

Q2: 如何为不同用户配置不同的模型凭据?

A2: 在应用层根据用户身份获取对应的凭据配置,然后在调用模型时传递相应的凭据字典。例如:

user_credentials = get_credentials_for_user(user_id)
model_result = model_instance.invoke(user_credentials, prompt_messages, model_parameters)

Q3: 如何扩展支持新的模型类型?

A3: 需要在模型层创建新的模型类,继承自 BaseModel 并实现特定的调用逻辑。同时,在工厂方法层注册新的模型类型,使其能够被应用层访问。

六、总结

Dify 的 model_runtime 模块通过清晰的分层架构和强大的功能设计,为开发者提供了一个高效、灵活且易于扩展的模型管理与调用解决方案。本文详细解析了模块的核心架构、功能实现和扩展方法,结合丰富的代码示例和图表,展示了 model_runtime 模块的运作原理和使用技巧。无论是对于初学者还是有经验的开发者,model_runtime 模块都提供了一个优秀的模型管理框架,帮助大家在人工智能应用开发中事半功倍。

希望本文能够帮助读者深入理解 Dify 的 model_runtime 模块,并在实际项目中灵活运用。如果您在使用过程中遇到任何问题或有新的见解,欢迎随时交流和分享。