【Dify(v1.x) 核心源码深入解析】Helper 模块

19 阅读16分钟

重磅推荐专栏: 《大模型AIGC》 《课程大纲》 《知识星球》

本专栏致力于探索和讨论当今最前沿的技术趋势和应用领域,包括但不限于ChatGPT和Stable Diffusion等。我们将深入研究大型模型的开发和应用,以及与之相关的人工智能生成内容(AIGC)技术。通过深入的技术解析和实践经验分享,旨在帮助读者更好地理解和应用这些领域的最新进展

一、引言

Dify 是一个功能强大的 AI 应用开发框架,其 helper 模块为开发者提供了丰富的工具函数和类,以简化开发过程并提高代码的可维护性和可扩展性。本文将深入剖析 Dify 的 helper 模块,涵盖从基础功能到高级实现的各个方面,通过详细的代码解读和示例,帮助读者全面掌握这些模块的使用方法。

二、模块概览

Dify 的 helper 模块主要包括以下几个方面:

  1. URL 签名模块:用于生成和验证带有签名的 URL。
  2. 缓存模块:包括工具提供商、工具参数、模型提供商和负载均衡模型的缓存操作。
  3. SSRF 代理请求模块:用于安全地进行代理请求,避免 SSRF 攻击。
  4. 位置帮助函数模块:提供对工具和提供商位置排序的功能。
  5. 模块导入帮助函数模块:简化从源文件导入模块的过程。
  6. 内容审核模块:用于检查文本内容是否符合审核标准。
  7. LRU 缓存模块:实现简单的 LRU(Least Recently Used)缓存机制。
  8. 加密解密模块:提供令牌的加密和解密功能。
  9. 下载与大小限制模块:用于下载文件并限制下载大小。

三、详细解读

1. URL 签名模块

功能概述

URL 签名模块主要用于生成带有签名的 URL,以确保 URL 的安全性和有效性。它通过使用 HMAC-SHA256 算法对特定数据进行签名,并将签名结果附加到 URL 查询参数中。验证时,再次计算签名并与提供的签名进行比较。

关键类和方法

  • SignedUrlParams:一个 Pydantic 模型,定义了签名 URL 参数的结构,包括签名密钥(sign_key)、时间戳(timestamp)、随机数(nonce)和签名(sign)。
class SignedUrlParams(BaseModel):
    sign_key: str = Field(..., description="The sign key")
    timestamp: str = Field(..., description="Timestamp")
    nonce: str = Field(..., description="Nonce")
    sign: str = Field(..., description="Signature")
  • UrlSigner:一个工具类,提供了生成签名 URL 和验证签名的方法。
class UrlSigner:
    @classmethod
    def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
        signed_url_params = cls.get_signed_url_params(sign_key, prefix)
        return (
            f"{url}?timestamp={signed_url_params.timestamp}"
            f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
        )

    @classmethod
    def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        sign = cls._sign(sign_key, timestamp, nonce, prefix)

        return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)

    @classmethod
    def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
        recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)

        return sign == recalculated_sign

    @classmethod
    def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
        if not dify_config.SECRET_KEY:
            raise Exception("SECRET_KEY is not set")

        data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

        return encoded_sign

使用示例

signed_url = UrlSigner.get_signed_url("https://example.com/api", "my_sign_key", "api")
print(signed_url)

# 验证签名
is_valid = UrlSigner.verify(
    "my_sign_key",
    "1696784400",
    "random_nonce",
    "valid_signature",
    "api"
)
print(is_valid)

关键点解释

  • _sign 方法使用 HMAC-SHA256 算法对指定数据进行签名,并将结果进行 Base64 编码。
  • get_signed_url_params 方法生成时间戳、随机数和签名,并返回一个 SignedUrlParams 对象。
  • get_signed_url 方法将签名参数附加到 URL 查询字符串中。
  • verify 方法重新计算签名并将其与提供的签名进行比较,以验证 URL 的有效性。

流程图

graph TD
    A[开始] --> B[获取时间戳]
    B --> C[生成随机数]
    C --> D[生成签名数据]
    D --> E[使用 HMAC-SHA256 计算签名]
    E --> F[Base64 编码签名]
    F --> G[创建 SignedUrlParams 对象]
    G --> H[附加签名参数到 URL]
    H --> I[返回签名 URL]

2. 缓存模块

功能概述

缓存模块提供了对工具提供商、工具参数、模型提供商和负载均衡模型的缓存操作。这些类利用 Redis 进行数据缓存,提高数据访问的效率。

关键类

  • ToolProviderCredentialsCache:用于缓存工具提供商的凭据。
class ToolProviderCredentialsCache:
    def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        cached_provider_credentials = redis_client.get(self.cache_key)
        if cached_provider_credentials:
            try:
                cached_provider_credentials = cached_provider_credentials.decode("utf-8")
                cached_provider_credentials = json.loads(cached_provider_credentials)
            except JSONDecodeError:
                return None

            return dict(cached_provider_credentials)
        else:
            return None

    def set(self, credentials: dict) -> None:
        redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

    def delete(self) -> None:
        redis_client.delete(self.cache_key)
  • ToolParameterCache:用于缓存工具参数。
class ToolParameterCache:
    def __init__(
        self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
    ):
        self.cache_key = (
            f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
            f":identity_id:{identity_id}"
        )

    def get(self) -> Optional[dict]:
        cached_tool_parameter = redis_client.get(self.cache_key)
        if cached_tool_parameter:
            try:
                cached_tool_parameter = cached_tool_parameter.decode("utf-8")
                cached_tool_parameter = json.loads(cached_tool_parameter)
            except JSONDecodeError:
                return None

            return dict(cached_tool_parameter)
        else:
            return None

    def set(self, parameters: dict) -> None:
        redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

    def delete(self) -> None:
        redis_client.delete(self.cache_key)
  • ProviderCredentialsCache:用于缓存模型提供商的凭据。
class ProviderCredentialsCache:
    def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        cached_provider_credentials = redis_client.get(self.cache_key)
        if cached_provider_credentials:
            try:
                cached_provider_credentials = cached_provider_credentials.decode("utf-8")
                cached_provider_credentials = json.loads(cached_provider_credentials)
            except JSONDecodeError:
                return None

            return dict(cached_provider_credentials)
        else:
            return None

    def set(self, credentials: dict) -> None:
        redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

    def delete(self) -> None:
        redis_client.delete(self.cache_key)

使用示例

# 工具提供商缓存
tool_provider_cache = ToolProviderCredentialsCache("tenant_1", "identity_1", ToolProviderCredentialsCacheType.PROVIDER)
tool_provider_cache.set({"key": "value"})
print(tool_provider_cache.get())
tool_provider_cache.delete()

# 工具参数缓存
tool_parameter_cache = ToolParameterCache("tenant_1", "provider_1", "tool_1", ToolParameterCacheType.PARAMETER, "identity_1")
tool_parameter_cache.set({"param": "value"})
print(tool_parameter_cache.get())
tool_parameter_cache.delete()

# 模型提供商缓存
provider_cache = ProviderCredentialsCache("tenant_1", "identity_1", ProviderCredentialsCacheType.PROVIDER)
provider_cache.set({"key": "value"})
print(provider_cache.get())
provider_cache.delete()

关键点解释

  • 每个缓存类都具有 getsetdelete 方法,分别用于获取、设置和删除缓存数据。
  • 缓存键(cache_key)根据传入的参数动态生成,确保每个缓存项都有唯一的键。
  • 使用 Redis 的 setex 方法设置缓存,带有过期时间(86400 秒,即 1 天)。
  • 获取缓存时,会尝试解码和解析 JSON 数据,如果解析失败则返回 None

类图

classDiagram
    class ToolProviderCredentialsCache {
        +cache_key: str
        +get() -> Optional[dict]
        +set(credentials: dict) -> None
        +delete() -> None
    }

    class ToolParameterCache {
        +cache_key: str
        +get() -> Optional[dict]
        +set(parameters: dict) -> None
        +delete() -> None
    }

    class ProviderCredentialsCache {
        +cache_key: str
        +get() -> Optional[dict]
        +set(credentials: dict) -> None
        +delete() -> None
    }

    ToolProviderCredentialsCache <|-- ToolParameterCache
    ToolProviderCredentialsCache <|-- ProviderCredentialsCache

3. SSRF 代理请求模块

功能概述

SSRF(Server-Side Request Forgery)代理请求模块用于安全地进行代理请求,避免 SSRF 攻击。它通过设置代理服务器、重试机制和超时控制,确保请求的安全性和可靠性。

关键方法

def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        if "follow_redirects" not in kwargs:
            kwargs["follow_redirects"] = allow_redirects

    if "timeout" not in kwargs:
        kwargs["timeout"] = httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
        )

    if "ssl_verify" not in kwargs:
        kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY

    ssl_verify = kwargs.pop("ssl_verify")

    retries = 0
    while retries <= max_retries:
        try:
            if dify_config.SSRF_PROXY_ALL_URL:
                with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
                    response = client.request(method=method, url=url, **kwargs)
            elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
                proxy_mounts = {
                    "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
                    "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
                }
                with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
                    response = client.request(method=method, url=url, **kwargs)
            else:
                with httpx.Client(verify=ssl_verify) as client:
                    response = client.request(method=method, url=url, **kwargs)

            if response.status_code not in STATUS_FORCELIST:
                return response
            else:
                logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")

        except httpx.RequestError as e:
            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
            if max_retries == 0:
                raise

        retries += 1
        if retries <= max_retries:
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")


def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("GET", url, max_retries=max_retries, **kwargs)


def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("POST", url, max_retries=max_retries, **kwargs)


def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("PUT", url, max_retries=max_retries, **kwargs)


def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("PATCH", url, max_retries=max_retries, **kwargs)


def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("DELETE", url, max_retries=max_retries, **kwargs)


def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("HEAD", url, max_retries=max_retries, **kwargs)

使用示例

response = make_request("GET", "https://example.com/api/data")
print(response.status_code)
print(response.json())

response = get("https://example.com/api/data")
print(response.status_code)
print(response.json())

关键点解释

  • make_request 方法是核心方法,用于发送 HTTP 请求。它支持多种 HTTP 方法(GET、POST、PUT、PATCH、DELETE、HEAD)。
  • 该方法实现了重试机制,当请求失败时,会根据 max_retries 参数进行重试,并使用指数退避算法增加重试间隔。
  • 支持设置代理服务器(通过 SSRF_PROXY_ALL_URLSSRF_PROXY_HTTP_URLSSRF_PROXY_HTTPS_URL 配置)。
  • 自动处理超时设置,包括连接超时、读取超时和写入超时。
  • 如果响应状态码在 STATUS_FORCELIST 列表中(如 429、500、502、503、504),会记录警告日志并继续重试。

流程图

graph TD
    A[开始] --> B[检查重定向设置]
    B --> C[设置超时]
    C --> D[设置 SSL 验证]
    D --> E{是否配置代理}
    E -->|是| F[使用代理发送请求]
    E -->|否| G[发送请求]
    F --> H{请求是否成功}
    G --> H
    H -->|是| I[返回响应]
    H -->|否| J[记录警告日志]
    J --> K{重试次数是否用尽}
    K -->|否| L[指数退避,重试]
    K -->|是| M[抛出异常]

时序图

sequenceDiagram
    participant Client
    participant SSRFProxy

    Client->>SSRFProxy: make_request("GET", "https://example.com/api/data")
    SSRFProxy->>SSRFProxy: 设置超时和 SSL 验证
    SSRFProxy->>SSRFProxy: 检查代理配置
    SSRFProxy->>SSRFProxy: 使用代理发送请求
    SSRFProxy->>SSRFProxy: 检查响应状态码
    alt 状态码不在强制列表中
        SSRFProxy-->>Client: 返回响应
    else
        SSRFProxy->>SSRFProxy: 记录警告日志
        SSRFProxy->>SSRFProxy: 检查重试次数
        loop 重试
            SSRFProxy->>SSRFProxy: 指数退避等待
            SSRFProxy->>SSRFProxy: 重新发送请求
        end
        SSRFProxy-->>Client: 抛出 MaxRetriesExceededError
    end

4. 位置帮助函数模块

功能概述

位置帮助函数模块提供了一系列函数,用于从 YAML 文件中获取位置映射,并根据这些映射对对象进行排序和过滤。

关键函数

def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
    position_file_path = os.path.join(folder_path, file_name)
    yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
    positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
    return {name: index for index, name in enumerate(positions)}


def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
    position_map = get_position_map(folder_path, file_name=file_name)
    return pin_position_map(
        position_map,
        pin_list=dify_config.POSITION_TOOL_PINS_LIST,
    )


def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
    position_map = get_position_map(folder_path, file_name=file_name)
    return pin_position_map(
        position_map,
        pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
    )


def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
    positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])

    position_map = {name: idx for idx, name in enumerate(pin_list)}

    start_idx = len(position_map)
    for name in positions:
        if name not in position_map:
            position_map[name] = start_idx
            start_idx += 1

    return position_map


def is_filtered(
    include_set: set[str],
    exclude_set: set[str],
    data: Any,
    name_func: Callable[[Any], str],
) -> bool:
    if not data:
        return False
    if not include_set and not exclude_set:
        return False

    name = name_func(data)

    if name in exclude_set:  # exclude_set 优先级更高
        return True
    if include_set and name not in include_set:  # 如果 include_set 不为空,则仅包含其中的项
        return True
    return False


def sort_by_position_map(
    position_map: dict[str, int],
    data: list[Any],
    name_func: Callable[[Any], str],
) -> list[Any]:
    if not position_map or not data:
        return data

    return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))


def sort_to_dict_by_position_map(
    position_map: dict[str, int],
    data: list[Any],
    name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
    sorted_items = sort_by_position_map(position_map, data, name_func)
    return OrderedDict([(name_func(item), item) for item in sorted_items])

使用示例

# 获取工具位置映射
tool_position_map = get_tool_position_map("/path/to/tools/folder")
print(tool_position_map)

# 获取提供商位置映射
provider_position_map = get_provider_position_map("/path/to/providers/folder")
print(provider_position_map)

# 排序工具列表
tools = [{"name": "tool1"}, {"name": "tool2"}, {"name": "tool3"}]
sorted_tools = sort_by_position_map(tool_position_map, tools, lambda x: x["name"])
print([tool["name"] for tool in sorted_tools])

# 排序并转换为有序字典
sorted_tools_dict = sort_to_dict_by_position_map(tool_position_map, tools, lambda x: x["name"])
print(sorted_tools_dict)

关键点解释

  • get_position_map 函数从指定文件夹中的 YAML 文件加载位置列表,并将其转换为字典,键为名称,值为索引。
  • get_tool_position_mapget_provider_position_map 函数分别获取工具和提供商的位置映射,并应用固定列表(pin_list)。
  • pin_position_map 函数将固定列表中的项放置在位置映射的开头。
  • is_filtered 函数用于根据包含集和排除集对对象进行过滤。
  • sort_by_position_map 函数根据位置映射对对象列表进行排序。
  • sort_to_dict_by_position_map 函数将排序后的对象列表转换为有序字典。

类图

classDiagram
    class PositionHelper {
        +get_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]
        +get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]
        +get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]
        +pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]
        +is_filtered(include_set: set[str], exclude_set: set[str], data: Any, name_func: Callable[[Any], str]) -> bool
        +sort_by_position_map(position_map: dict[str, int], data: list[Any], name_func: Callable[[Any], str]) -> list[Any]
        +sort_to_dict_by_position_map(position_map: dict[str, int], data: list[Any], name_func: Callable[[Any], str]) -> OrderedDict[str, Any]
    }

流程图

graph TD
    A[开始] --> B[加载 YAML 文件]
    B --> C[解析 YAML 内容]
    C --> D[生成位置映射]
    D --> E[应用固定列表]
    E --> F[返回位置映射]

5. 模块导入帮助函数模块

功能概述

模块导入帮助函数模块提供了从源文件导入模块的功能,并支持动态加载子类。

关键函数

def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
    try:
        existed_spec = importlib.util.find_spec(module_name)
        if existed_spec:
            spec = existed_spec
            if not spec.loader:
                raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
        else:
            spec = importlib.util.spec_from_file_location(module_name, py_file_path)
            if not spec or not spec.loader:
                raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
            if use_lazy_loader:
                spec.loader = importlib.util.LazyLoader(spec.loader)
        module = importlib.util.module_from_spec(spec)
        if not existed_spec:
            sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module
    except Exception as e:
        logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
        raise e


def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:
    classes = [
        x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)
    ]
    return classes


def load_single_subclass_from_source(
    *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
) -> type:
    module = import_module_from_source(
        module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
    )
    subclasses = get_subclasses_from_module(module, parent_type)
    match len(subclasses):
        case 1:
            return subclasses[0]
        case 0:
            raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}")
        case _:
            raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}")

使用示例

# 从源文件导入模块
module = import_module_from_source(module_name="my_module", py_file_path="/path/to/my_module.py")

# 获取模块中的子类
subclasses = get_subclasses_from_module(module, ParentClass)
print(subclasses)

# 加载单个子类
subclass = load_single_subclass_from_source(
    module_name="my_module",
    script_path="/path/to/my_module.py",
    parent_type=ParentClass
)
print(subclass)

关键点解释

  • import_module_from_source 函数从指定的 Python 文件导入模块。它可以使用延迟加载器(use_lazy_loader)来延迟模块的加载。
  • get_subclasses_from_module 函数从模块中获取指定父类的所有子类。
  • load_single_subclass_from_source 函数从源文件中加载单个子类。如果找不到子类或找到多个子类,会抛出异常。

类图

classDiagram
    class ModuleImportHelper {
        +import_module_from_source(module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType
        +get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
        +load_single_subclass_from_source(module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False) -> type
    }

流程图

graph TD
    A[开始] --> B[检查模块是否已存在]
    B -->|是| C[使用现有模块规范]
    B -->|否| D[从文件位置创建模块规范]
    C --> E[检查加载器是否存在]
    D --> E
    E -->|加载器不存在| F[抛出异常]
    E -->|加载器存在| G[创建模块]
    G --> H[执行模块]
    H --> I[返回模块]

6. 内容审核模块

功能概述

内容审核模块用于检查文本内容是否符合审核标准。它利用 OpenAI 的审核模型对文本进行分析,并返回审核结果。

关键函数

def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
    moderation_config = hosting_configuration.moderation_config
    openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
    if (
        moderation_config
        and moderation_config.enabled is True
        and openai_provider_name in hosting_configuration.provider_map
        and hosting_configuration.provider_map[openai_provider_name].enabled is True
    ):
        using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
        provider_name = model_config.provider
        if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
            hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]

            if hosting_openai_config.credentials is None:
                return False

            length = 2000
            text_chunks = [text[i : i + length] for i in range(0, len(text), length)]

            if len(text_chunks) == 0:
                return True

            text_chunk = random.choice(text_chunks)

            try:
                model_provider_factory = ModelProviderFactory(tenant_id)

                model_type_instance = model_provider_factory.get_model_type_instance(
                    provider=openai_provider_name, model_type=ModelType.MODERATION
                )
                model_type_instance = cast(ModerationModel, model_type_instance)
                moderation_result = model_type_instance.invoke(
                    model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
                )

                if moderation_result is True:
                    return True
            except Exception:
                logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
                raise InvokeBadRequestError("Rate limit exceeded, please try again later.")

    return False

使用示例

# 检查文本内容是否符合审核标准
is_moderated = check_moderation(
    tenant_id="tenant_1",
    model_config=ModelConfigWithCredentialsEntity(
        provider_model_bundle=ModelProviderModelBundle(
            configuration=ModelProviderModelConfiguration(
                using_provider_type=ProviderType.SYSTEM
            )
        ),
        provider="provider_1"
    ),
    text="This is a sample text to check moderation."
)
print(is_moderated)

关键点解释

  • check_moderation 函数检查审核配置是否启用,并确保 OpenAI 提供商配置正确。
  • 它将文本分块(每块 2000 个字符),并随机选择一块进行审核。
  • 使用 ModelProviderFactory 创建审核模型实例,并调用其 invoke 方法进行审核。
  • 如果审核结果为 True,则返回 True,表示内容符合审核标准。
  • 如果审核过程中发生异常,会记录日志并抛出 InvokeBadRequestError

流程图

graph TD
    A[开始] --> B[检查审核配置]
    B -->|审核配置未启用| C[返回 False]
    B -->|审核配置启用| D[检查 OpenAI 提供商配置]
    D -->|配置无效| C
    D --> E[检查提供商类型]
    E -->|非系统提供商| C
    E --> F[检查提供商是否在审核列表中]
    F -->|不在列表中| C
    F --> G[获取 OpenAI 配置]
    G -->|凭证不存在| C
    G --> H[分块文本]
    H --> I[选择随机文本块]
    I --> J[创建审核模型实例]
    J --> K[调用审核模型]
    K --> L{审核结果}
    L -->|True| M[返回 True]
    L -->|False| C
    K -->|异常| N[记录日志并抛出异常]

7. LRU 缓存模块

功能概述

LRU 缓存模块实现了一个简单的 LRU(Least Recently Used)缓存机制,用于存储和管理缓存项。

关键类

class LRUCache:
    def __init__(self, capacity: int):
        self.cache: OrderedDict[Any, Any] = OrderedDict()
        self.capacity = capacity

    def get(self, key: Any) -> Any:
        if key not in self.cache:
            return None
        else:
            self.cache.move_to_end(key)  # 将键移动到有序字典的末尾
            return self.cache[key]

    def put(self, key: Any, value: Any) -> None:
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)  # 弹出第一个项

使用示例

# 创建 LRU 缓存实例
lru_cache = LRUCache(capacity=3)

# 添加缓存项
lru_cache.put("key1", "value1")
lru_cache.put("key2", "value2")
lru_cache.put("key3", "value3")

# 获取缓存项
print(lru_cache.get("key1"))  # 输出: value1

# 添加新缓存项(超出容量)
lru_cache.put("key4", "value4")
print(lru_cache.get("key2"))  # 输出: None(已被淘汰)

关键点解释

  • LRUCache 类使用 OrderedDict 来存储缓存项,以便跟踪最近使用的项。
  • get 方法检索缓存项,并将其移动到有序字典的末尾(表示最近使用)。
  • put 方法添加或更新缓存项,并将其移动到有序字典的末尾。如果缓存超出容量,会从开头弹出最久未使用的项。

类图

classDiagram
    class LRUCache {
        +cache: OrderedDict[Any, Any]
        +capacity: int
        +__init__(capacity: int) -> None
        +get(key: Any) -> Any
        +put(key: Any, value: Any) -> None
    }

流程图

graph TD
    A[开始] --> B[检查键是否存在]
    B -->|存在| C[移动到末尾]
    C --> D[返回值]
    B -->|不存在| E[返回 None]
    A --> F[添加键值对]
    F --> G[检查键是否存在]
    G -->|存在| H[移动到末尾]
    G -->|不存在| I[添加键值对]
    I --> J[检查容量]
    J -->|超过容量| K[弹出第一个项]
    J -->|未超过容量| L[完成]

8. 加密解密模块

功能概述

加密解密模块提供了对令牌的加密和解密功能,确保数据的安全性。

关键函数

def obfuscated_token(token: str):
    if not token:
        return token
    if len(token) <= 8:
        return "*" * 20
    return token[:6] + "*" * 12 + token[-2:]


def encrypt_token(tenant_id: str, token: str):
    from models.account import Tenant
    from models.engine import db

    if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
        raise ValueError(f"Tenant with id {tenant_id} not found")
    encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
    return base64.b64encode(encrypted_token).decode()


def decrypt_token(tenant_id: str, token: str):
    return rsa.decrypt(base64.b64decode(token), tenant_id)


def batch_decrypt_token(tenant_id: str, tokens: list[str]):
    rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)

    return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens]


def get_decrypt_decoding(tenant_id: str):
    return rsa.get_decrypt_decoding(tenant_id)


def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
    return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)

使用示例

# 加密令牌
encrypted_token = encrypt_token("tenant_1", "plain_text_token")
print(encrypted_token)

# 解密令牌
decrypted_token = decrypt_token("tenant_1", encrypted_token)
print(decrypted_token)

# 批量解密令牌
decrypted_tokens = batch_decrypt_token("tenant_1", [encrypted_token, encrypted_token])
print(decrypted_tokens)

关键点解释

  • obfuscated_token 函数对令牌进行混淆处理,隐藏大部分内容。
  • encrypt_token 函数使用 RSA 加密算法对令牌进行加密,并将其编码为 Base64。
  • decrypt_token 函数使用 RSA 解密算法对令牌进行解密。
  • batch_decrypt_token 函数批量解密令牌列表。
  • get_decrypt_decodingdecrypt_token_with_decoding 函数提供了解密的辅助功能。

流程图

graph TD
    A[开始] --> B{令牌长度 <= 8}
    B -->|是| C[返回掩码字符串]
    B -->|否| D[返回部分掩码令牌]
    A --> E[加密令牌]
    E --> F[获取租户]
    F -->|租户不存在| G[抛出异常]
    F -->|租户存在| H[加密令牌]
    H --> I[Base64 编码]
    I --> J[返回加密令牌]
    A --> K[解密令牌]
    K --> L[Base64 解码]
    L --> M[解密令牌]
    M --> N[返回解密令牌]

9. 下载与大小限制模块

功能概述

下载与大小限制模块用于下载文件,并限制下载大小,以防止下载过大的文件。

关键函数

from core.helper import ssrf_proxy


def download_with_size_limit(url, max_download_size: int, **kwargs):
    response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
    if response.status_code == 404:
        raise ValueError("file not found")

    total_size = 0
    chunks = []
    for chunk in response.iter_bytes():
        total_size += len(chunk)
        if total_size > max_download_size:
            raise ValueError("Max file size reached")
        chunks.append(chunk)
    content = b"".join(chunks)
    return content

使用示例

# 下载文件并限制大小
content = download_with_size_limit("https://example.com/file.zip", 1024 * 1024)  # 限制为 1MB
print(len(content))

关键点解释

  • download_with_size_limit 函数使用 SSRF 代理模块的 get 方法下载文件。
  • 它检查响应状态码,如果为 404,则抛出异常。
  • 通过迭代响应的字节块,逐步下载文件内容,同时跟踪总下载大小。
  • 如果下载大小超过指定限制,会抛出 ValueError
  • 最后,将所有字节块合并成一个字节对象并返回。

流程图

graph TD
    A[开始] --> B[发送 GET 请求]
    B --> C{状态码 == 404}
    C -->|是| D[抛出异常]
    C -->|否| E[初始化变量]
    E --> F[迭代响应字节块]
    F --> G[累加块大小]
    G --> H{总大小 > 限制}
    H -->|是| I[抛出异常]
    H -->|否| J[添加块到列表]
    J --> F
    F --> K[合并块]
    K --> L[返回内容]

四、总结

Dify 的 helper 模块提供了一系列功能强大的工具,涵盖了从 URL 签名、缓存操作、安全代理请求、位置排序、模块导入、内容审核、LRU 缓存、加密解密到文件下载等多个方面。这些模块的设计旨在简化开发过程,提高代码的可维护性和可扩展性。

通过本文的详细解读,我们希望读者能够深入理解每个模块的功能和实现原理,并能够灵活运用这些工具来构建高效、安全的 AI 应用。无论是初学者还是有经验的开发者,都可以从这些模块中获益,加速项目的开发和部署。

在未来的发展中,Dify 的 helper 模块预计会不断扩展和优化,以满足不断变化的技术需求和应用场景。开发者可以持续关注这些模块的更新,以便充分利用其新功能和改进。