本专栏致力于探索和讨论当今最前沿的技术趋势和应用领域,包括但不限于ChatGPT和Stable Diffusion等。我们将深入研究大型模型的开发和应用,以及与之相关的人工智能生成内容(AIGC)技术。通过深入的技术解析和实践经验分享,旨在帮助读者更好地理解和应用这些领域的最新进展
一、引言
Dify 是一个功能强大的 AI 应用开发框架,其 helper 模块为开发者提供了丰富的工具函数和类,以简化开发过程并提高代码的可维护性和可扩展性。本文将深入剖析 Dify 的 helper 模块,涵盖从基础功能到高级实现的各个方面,通过详细的代码解读和示例,帮助读者全面掌握这些模块的使用方法。
二、模块概览
Dify 的 helper 模块主要包括以下几个方面:
- URL 签名模块:用于生成和验证带有签名的 URL。
- 缓存模块:包括工具提供商、工具参数、模型提供商和负载均衡模型的缓存操作。
- SSRF 代理请求模块:用于安全地进行代理请求,避免 SSRF 攻击。
- 位置帮助函数模块:提供对工具和提供商位置排序的功能。
- 模块导入帮助函数模块:简化从源文件导入模块的过程。
- 内容审核模块:用于检查文本内容是否符合审核标准。
- LRU 缓存模块:实现简单的 LRU(Least Recently Used)缓存机制。
- 加密解密模块:提供令牌的加密和解密功能。
- 下载与大小限制模块:用于下载文件并限制下载大小。
三、详细解读
1. URL 签名模块
功能概述
URL 签名模块主要用于生成带有签名的 URL,以确保 URL 的安全性和有效性。它通过使用 HMAC-SHA256 算法对特定数据进行签名,并将签名结果附加到 URL 查询参数中。验证时,再次计算签名并与提供的签名进行比较。
关键类和方法
- SignedUrlParams:一个 Pydantic 模型,定义了签名 URL 参数的结构,包括签名密钥(
sign_key
)、时间戳(timestamp
)、随机数(nonce
)和签名(sign
)。
class SignedUrlParams(BaseModel):
sign_key: str = Field(..., description="The sign key")
timestamp: str = Field(..., description="Timestamp")
nonce: str = Field(..., description="Nonce")
sign: str = Field(..., description="Signature")
- UrlSigner:一个工具类,提供了生成签名 URL 和验证签名的方法。
class UrlSigner:
@classmethod
def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
signed_url_params = cls.get_signed_url_params(sign_key, prefix)
return (
f"{url}?timestamp={signed_url_params.timestamp}"
f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
)
@classmethod
def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
sign = cls._sign(sign_key, timestamp, nonce, prefix)
return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)
@classmethod
def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)
return sign == recalculated_sign
@classmethod
def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
if not dify_config.SECRET_KEY:
raise Exception("SECRET_KEY is not set")
data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return encoded_sign
使用示例
signed_url = UrlSigner.get_signed_url("https://example.com/api", "my_sign_key", "api")
print(signed_url)
# 验证签名
is_valid = UrlSigner.verify(
"my_sign_key",
"1696784400",
"random_nonce",
"valid_signature",
"api"
)
print(is_valid)
关键点解释
_sign
方法使用 HMAC-SHA256 算法对指定数据进行签名,并将结果进行 Base64 编码。get_signed_url_params
方法生成时间戳、随机数和签名,并返回一个SignedUrlParams
对象。get_signed_url
方法将签名参数附加到 URL 查询字符串中。verify
方法重新计算签名并将其与提供的签名进行比较,以验证 URL 的有效性。
流程图
graph TD
A[开始] --> B[获取时间戳]
B --> C[生成随机数]
C --> D[生成签名数据]
D --> E[使用 HMAC-SHA256 计算签名]
E --> F[Base64 编码签名]
F --> G[创建 SignedUrlParams 对象]
G --> H[附加签名参数到 URL]
H --> I[返回签名 URL]
2. 缓存模块
功能概述
缓存模块提供了对工具提供商、工具参数、模型提供商和负载均衡模型的缓存操作。这些类利用 Redis 进行数据缓存,提高数据访问的效率。
关键类
- ToolProviderCredentialsCache:用于缓存工具提供商的凭据。
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return dict(cached_provider_credentials)
else:
return None
def set(self, credentials: dict) -> None:
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
redis_client.delete(self.cache_key)
- ToolParameterCache:用于缓存工具参数。
class ToolParameterCache:
def __init__(
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
):
self.cache_key = (
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f":identity_id:{identity_id}"
)
def get(self) -> Optional[dict]:
cached_tool_parameter = redis_client.get(self.cache_key)
if cached_tool_parameter:
try:
cached_tool_parameter = cached_tool_parameter.decode("utf-8")
cached_tool_parameter = json.loads(cached_tool_parameter)
except JSONDecodeError:
return None
return dict(cached_tool_parameter)
else:
return None
def set(self, parameters: dict) -> None:
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None:
redis_client.delete(self.cache_key)
- ProviderCredentialsCache:用于缓存模型提供商的凭据。
class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return dict(cached_provider_credentials)
else:
return None
def set(self, credentials: dict) -> None:
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
redis_client.delete(self.cache_key)
使用示例
# 工具提供商缓存
tool_provider_cache = ToolProviderCredentialsCache("tenant_1", "identity_1", ToolProviderCredentialsCacheType.PROVIDER)
tool_provider_cache.set({"key": "value"})
print(tool_provider_cache.get())
tool_provider_cache.delete()
# 工具参数缓存
tool_parameter_cache = ToolParameterCache("tenant_1", "provider_1", "tool_1", ToolParameterCacheType.PARAMETER, "identity_1")
tool_parameter_cache.set({"param": "value"})
print(tool_parameter_cache.get())
tool_parameter_cache.delete()
# 模型提供商缓存
provider_cache = ProviderCredentialsCache("tenant_1", "identity_1", ProviderCredentialsCacheType.PROVIDER)
provider_cache.set({"key": "value"})
print(provider_cache.get())
provider_cache.delete()
关键点解释
- 每个缓存类都具有
get
、set
和delete
方法,分别用于获取、设置和删除缓存数据。 - 缓存键(
cache_key
)根据传入的参数动态生成,确保每个缓存项都有唯一的键。 - 使用 Redis 的
setex
方法设置缓存,带有过期时间(86400 秒,即 1 天)。 - 获取缓存时,会尝试解码和解析 JSON 数据,如果解析失败则返回
None
。
类图
classDiagram
class ToolProviderCredentialsCache {
+cache_key: str
+get() -> Optional[dict]
+set(credentials: dict) -> None
+delete() -> None
}
class ToolParameterCache {
+cache_key: str
+get() -> Optional[dict]
+set(parameters: dict) -> None
+delete() -> None
}
class ProviderCredentialsCache {
+cache_key: str
+get() -> Optional[dict]
+set(credentials: dict) -> None
+delete() -> None
}
ToolProviderCredentialsCache <|-- ToolParameterCache
ToolProviderCredentialsCache <|-- ProviderCredentialsCache
3. SSRF 代理请求模块
功能概述
SSRF(Server-Side Request Forgery)代理请求模块用于安全地进行代理请求,避免 SSRF 攻击。它通过设置代理服务器、重试机制和超时控制,确保请求的安全性和可靠性。
关键方法
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
kwargs["follow_redirects"] = allow_redirects
if "timeout" not in kwargs:
kwargs["timeout"] = httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
)
if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
ssl_verify = kwargs.pop("ssl_verify")
retries = 0
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
}
with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
else:
with httpx.Client(verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
return response
else:
logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
except httpx.RequestError as e:
logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
if max_retries == 0:
raise
retries += 1
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("GET", url, max_retries=max_retries, **kwargs)
def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("POST", url, max_retries=max_retries, **kwargs)
def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("PUT", url, max_retries=max_retries, **kwargs)
def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("PATCH", url, max_retries=max_retries, **kwargs)
def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("DELETE", url, max_retries=max_retries, **kwargs)
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
使用示例
response = make_request("GET", "https://example.com/api/data")
print(response.status_code)
print(response.json())
response = get("https://example.com/api/data")
print(response.status_code)
print(response.json())
关键点解释
make_request
方法是核心方法,用于发送 HTTP 请求。它支持多种 HTTP 方法(GET、POST、PUT、PATCH、DELETE、HEAD)。- 该方法实现了重试机制,当请求失败时,会根据
max_retries
参数进行重试,并使用指数退避算法增加重试间隔。 - 支持设置代理服务器(通过
SSRF_PROXY_ALL_URL
、SSRF_PROXY_HTTP_URL
和SSRF_PROXY_HTTPS_URL
配置)。 - 自动处理超时设置,包括连接超时、读取超时和写入超时。
- 如果响应状态码在
STATUS_FORCELIST
列表中(如 429、500、502、503、504),会记录警告日志并继续重试。
流程图
graph TD
A[开始] --> B[检查重定向设置]
B --> C[设置超时]
C --> D[设置 SSL 验证]
D --> E{是否配置代理}
E -->|是| F[使用代理发送请求]
E -->|否| G[发送请求]
F --> H{请求是否成功}
G --> H
H -->|是| I[返回响应]
H -->|否| J[记录警告日志]
J --> K{重试次数是否用尽}
K -->|否| L[指数退避,重试]
K -->|是| M[抛出异常]
时序图
sequenceDiagram
participant Client
participant SSRFProxy
Client->>SSRFProxy: make_request("GET", "https://example.com/api/data")
SSRFProxy->>SSRFProxy: 设置超时和 SSL 验证
SSRFProxy->>SSRFProxy: 检查代理配置
SSRFProxy->>SSRFProxy: 使用代理发送请求
SSRFProxy->>SSRFProxy: 检查响应状态码
alt 状态码不在强制列表中
SSRFProxy-->>Client: 返回响应
else
SSRFProxy->>SSRFProxy: 记录警告日志
SSRFProxy->>SSRFProxy: 检查重试次数
loop 重试
SSRFProxy->>SSRFProxy: 指数退避等待
SSRFProxy->>SSRFProxy: 重新发送请求
end
SSRFProxy-->>Client: 抛出 MaxRetriesExceededError
end
4. 位置帮助函数模块
功能概述
位置帮助函数模块提供了一系列函数,用于从 YAML 文件中获取位置映射,并根据这些映射对对象进行排序和过滤。
关键函数
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
position_file_path = os.path.join(folder_path, file_name)
yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
return {name: index for index, name in enumerate(positions)}
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
)
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
)
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
position_map = {name: idx for idx, name in enumerate(pin_list)}
start_idx = len(position_map)
for name in positions:
if name not in position_map:
position_map[name] = start_idx
start_idx += 1
return position_map
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
) -> bool:
if not data:
return False
if not include_set and not exclude_set:
return False
name = name_func(data)
if name in exclude_set: # exclude_set 优先级更高
return True
if include_set and name not in include_set: # 如果 include_set 不为空,则仅包含其中的项
return True
return False
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
if not position_map or not data:
return data
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])
使用示例
# 获取工具位置映射
tool_position_map = get_tool_position_map("/path/to/tools/folder")
print(tool_position_map)
# 获取提供商位置映射
provider_position_map = get_provider_position_map("/path/to/providers/folder")
print(provider_position_map)
# 排序工具列表
tools = [{"name": "tool1"}, {"name": "tool2"}, {"name": "tool3"}]
sorted_tools = sort_by_position_map(tool_position_map, tools, lambda x: x["name"])
print([tool["name"] for tool in sorted_tools])
# 排序并转换为有序字典
sorted_tools_dict = sort_to_dict_by_position_map(tool_position_map, tools, lambda x: x["name"])
print(sorted_tools_dict)
关键点解释
get_position_map
函数从指定文件夹中的 YAML 文件加载位置列表,并将其转换为字典,键为名称,值为索引。get_tool_position_map
和get_provider_position_map
函数分别获取工具和提供商的位置映射,并应用固定列表(pin_list
)。pin_position_map
函数将固定列表中的项放置在位置映射的开头。is_filtered
函数用于根据包含集和排除集对对象进行过滤。sort_by_position_map
函数根据位置映射对对象列表进行排序。sort_to_dict_by_position_map
函数将排序后的对象列表转换为有序字典。
类图
classDiagram
class PositionHelper {
+get_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]
+get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]
+get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]
+pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]
+is_filtered(include_set: set[str], exclude_set: set[str], data: Any, name_func: Callable[[Any], str]) -> bool
+sort_by_position_map(position_map: dict[str, int], data: list[Any], name_func: Callable[[Any], str]) -> list[Any]
+sort_to_dict_by_position_map(position_map: dict[str, int], data: list[Any], name_func: Callable[[Any], str]) -> OrderedDict[str, Any]
}
流程图
graph TD
A[开始] --> B[加载 YAML 文件]
B --> C[解析 YAML 内容]
C --> D[生成位置映射]
D --> E[应用固定列表]
E --> F[返回位置映射]
5. 模块导入帮助函数模块
功能概述
模块导入帮助函数模块提供了从源文件导入模块的功能,并支持动态加载子类。
关键函数
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
try:
existed_spec = importlib.util.find_spec(module_name)
if existed_spec:
spec = existed_spec
if not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
else:
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:
spec.loader = importlib.util.LazyLoader(spec.loader)
module = importlib.util.module_from_spec(spec)
if not existed_spec:
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
raise e
def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:
classes = [
x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)
]
return classes
def load_single_subclass_from_source(
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
) -> type:
module = import_module_from_source(
module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
)
subclasses = get_subclasses_from_module(module, parent_type)
match len(subclasses):
case 1:
return subclasses[0]
case 0:
raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}")
case _:
raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}")
使用示例
# 从源文件导入模块
module = import_module_from_source(module_name="my_module", py_file_path="/path/to/my_module.py")
# 获取模块中的子类
subclasses = get_subclasses_from_module(module, ParentClass)
print(subclasses)
# 加载单个子类
subclass = load_single_subclass_from_source(
module_name="my_module",
script_path="/path/to/my_module.py",
parent_type=ParentClass
)
print(subclass)
关键点解释
import_module_from_source
函数从指定的 Python 文件导入模块。它可以使用延迟加载器(use_lazy_loader
)来延迟模块的加载。get_subclasses_from_module
函数从模块中获取指定父类的所有子类。load_single_subclass_from_source
函数从源文件中加载单个子类。如果找不到子类或找到多个子类,会抛出异常。
类图
classDiagram
class ModuleImportHelper {
+import_module_from_source(module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType
+get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
+load_single_subclass_from_source(module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False) -> type
}
流程图
graph TD
A[开始] --> B[检查模块是否已存在]
B -->|是| C[使用现有模块规范]
B -->|否| D[从文件位置创建模块规范]
C --> E[检查加载器是否存在]
D --> E
E -->|加载器不存在| F[抛出异常]
E -->|加载器存在| G[创建模块]
G --> H[执行模块]
H --> I[返回模块]
6. 内容审核模块
功能概述
内容审核模块用于检查文本内容是否符合审核标准。它利用 OpenAI 的审核模型对文本进行分析,并返回审核结果。
关键函数
def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
moderation_config = hosting_configuration.moderation_config
openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
if (
moderation_config
and moderation_config.enabled is True
and openai_provider_name in hosting_configuration.provider_map
and hosting_configuration.provider_map[openai_provider_name].enabled is True
):
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
provider_name = model_config.provider
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]
if hosting_openai_config.credentials is None:
return False
length = 2000
text_chunks = [text[i : i + length] for i in range(0, len(text), length)]
if len(text_chunks) == 0:
return True
text_chunk = random.choice(text_chunks)
try:
model_provider_factory = ModelProviderFactory(tenant_id)
model_type_instance = model_provider_factory.get_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)
moderation_result = model_type_instance.invoke(
model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
)
if moderation_result is True:
return True
except Exception:
logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
return False
使用示例
# 检查文本内容是否符合审核标准
is_moderated = check_moderation(
tenant_id="tenant_1",
model_config=ModelConfigWithCredentialsEntity(
provider_model_bundle=ModelProviderModelBundle(
configuration=ModelProviderModelConfiguration(
using_provider_type=ProviderType.SYSTEM
)
),
provider="provider_1"
),
text="This is a sample text to check moderation."
)
print(is_moderated)
关键点解释
check_moderation
函数检查审核配置是否启用,并确保 OpenAI 提供商配置正确。- 它将文本分块(每块 2000 个字符),并随机选择一块进行审核。
- 使用
ModelProviderFactory
创建审核模型实例,并调用其invoke
方法进行审核。 - 如果审核结果为
True
,则返回True
,表示内容符合审核标准。 - 如果审核过程中发生异常,会记录日志并抛出
InvokeBadRequestError
。
流程图
graph TD
A[开始] --> B[检查审核配置]
B -->|审核配置未启用| C[返回 False]
B -->|审核配置启用| D[检查 OpenAI 提供商配置]
D -->|配置无效| C
D --> E[检查提供商类型]
E -->|非系统提供商| C
E --> F[检查提供商是否在审核列表中]
F -->|不在列表中| C
F --> G[获取 OpenAI 配置]
G -->|凭证不存在| C
G --> H[分块文本]
H --> I[选择随机文本块]
I --> J[创建审核模型实例]
J --> K[调用审核模型]
K --> L{审核结果}
L -->|True| M[返回 True]
L -->|False| C
K -->|异常| N[记录日志并抛出异常]
7. LRU 缓存模块
功能概述
LRU 缓存模块实现了一个简单的 LRU(Least Recently Used)缓存机制,用于存储和管理缓存项。
关键类
class LRUCache:
def __init__(self, capacity: int):
self.cache: OrderedDict[Any, Any] = OrderedDict()
self.capacity = capacity
def get(self, key: Any) -> Any:
if key not in self.cache:
return None
else:
self.cache.move_to_end(key) # 将键移动到有序字典的末尾
return self.cache[key]
def put(self, key: Any, value: Any) -> None:
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
if len(self.cache) > self.capacity:
self.cache.popitem(last=False) # 弹出第一个项
使用示例
# 创建 LRU 缓存实例
lru_cache = LRUCache(capacity=3)
# 添加缓存项
lru_cache.put("key1", "value1")
lru_cache.put("key2", "value2")
lru_cache.put("key3", "value3")
# 获取缓存项
print(lru_cache.get("key1")) # 输出: value1
# 添加新缓存项(超出容量)
lru_cache.put("key4", "value4")
print(lru_cache.get("key2")) # 输出: None(已被淘汰)
关键点解释
LRUCache
类使用OrderedDict
来存储缓存项,以便跟踪最近使用的项。get
方法检索缓存项,并将其移动到有序字典的末尾(表示最近使用)。put
方法添加或更新缓存项,并将其移动到有序字典的末尾。如果缓存超出容量,会从开头弹出最久未使用的项。
类图
classDiagram
class LRUCache {
+cache: OrderedDict[Any, Any]
+capacity: int
+__init__(capacity: int) -> None
+get(key: Any) -> Any
+put(key: Any, value: Any) -> None
}
流程图
graph TD
A[开始] --> B[检查键是否存在]
B -->|存在| C[移动到末尾]
C --> D[返回值]
B -->|不存在| E[返回 None]
A --> F[添加键值对]
F --> G[检查键是否存在]
G -->|存在| H[移动到末尾]
G -->|不存在| I[添加键值对]
I --> J[检查容量]
J -->|超过容量| K[弹出第一个项]
J -->|未超过容量| L[完成]
8. 加密解密模块
功能概述
加密解密模块提供了对令牌的加密和解密功能,确保数据的安全性。
关键函数
def obfuscated_token(token: str):
if not token:
return token
if len(token) <= 8:
return "*" * 20
return token[:6] + "*" * 12 + token[-2:]
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(tenant_id: str, token: str):
return rsa.decrypt(base64.b64decode(token), tenant_id)
def batch_decrypt_token(tenant_id: str, tokens: list[str]):
rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)
return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens]
def get_decrypt_decoding(tenant_id: str):
return rsa.get_decrypt_decoding(tenant_id)
def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)
使用示例
# 加密令牌
encrypted_token = encrypt_token("tenant_1", "plain_text_token")
print(encrypted_token)
# 解密令牌
decrypted_token = decrypt_token("tenant_1", encrypted_token)
print(decrypted_token)
# 批量解密令牌
decrypted_tokens = batch_decrypt_token("tenant_1", [encrypted_token, encrypted_token])
print(decrypted_tokens)
关键点解释
obfuscated_token
函数对令牌进行混淆处理,隐藏大部分内容。encrypt_token
函数使用 RSA 加密算法对令牌进行加密,并将其编码为 Base64。decrypt_token
函数使用 RSA 解密算法对令牌进行解密。batch_decrypt_token
函数批量解密令牌列表。get_decrypt_decoding
和decrypt_token_with_decoding
函数提供了解密的辅助功能。
流程图
graph TD
A[开始] --> B{令牌长度 <= 8}
B -->|是| C[返回掩码字符串]
B -->|否| D[返回部分掩码令牌]
A --> E[加密令牌]
E --> F[获取租户]
F -->|租户不存在| G[抛出异常]
F -->|租户存在| H[加密令牌]
H --> I[Base64 编码]
I --> J[返回加密令牌]
A --> K[解密令牌]
K --> L[Base64 解码]
L --> M[解密令牌]
M --> N[返回解密令牌]
9. 下载与大小限制模块
功能概述
下载与大小限制模块用于下载文件,并限制下载大小,以防止下载过大的文件。
关键函数
from core.helper import ssrf_proxy
def download_with_size_limit(url, max_download_size: int, **kwargs):
response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
if response.status_code == 404:
raise ValueError("file not found")
total_size = 0
chunks = []
for chunk in response.iter_bytes():
total_size += len(chunk)
if total_size > max_download_size:
raise ValueError("Max file size reached")
chunks.append(chunk)
content = b"".join(chunks)
return content
使用示例
# 下载文件并限制大小
content = download_with_size_limit("https://example.com/file.zip", 1024 * 1024) # 限制为 1MB
print(len(content))
关键点解释
download_with_size_limit
函数使用 SSRF 代理模块的get
方法下载文件。- 它检查响应状态码,如果为 404,则抛出异常。
- 通过迭代响应的字节块,逐步下载文件内容,同时跟踪总下载大小。
- 如果下载大小超过指定限制,会抛出
ValueError
。 - 最后,将所有字节块合并成一个字节对象并返回。
流程图
graph TD
A[开始] --> B[发送 GET 请求]
B --> C{状态码 == 404}
C -->|是| D[抛出异常]
C -->|否| E[初始化变量]
E --> F[迭代响应字节块]
F --> G[累加块大小]
G --> H{总大小 > 限制}
H -->|是| I[抛出异常]
H -->|否| J[添加块到列表]
J --> F
F --> K[合并块]
K --> L[返回内容]
四、总结
Dify 的 helper 模块提供了一系列功能强大的工具,涵盖了从 URL 签名、缓存操作、安全代理请求、位置排序、模块导入、内容审核、LRU 缓存、加密解密到文件下载等多个方面。这些模块的设计旨在简化开发过程,提高代码的可维护性和可扩展性。
通过本文的详细解读,我们希望读者能够深入理解每个模块的功能和实现原理,并能够灵活运用这些工具来构建高效、安全的 AI 应用。无论是初学者还是有经验的开发者,都可以从这些模块中获益,加速项目的开发和部署。
在未来的发展中,Dify 的 helper 模块预计会不断扩展和优化,以满足不断变化的技术需求和应用场景。开发者可以持续关注这些模块的更新,以便充分利用其新功能和改进。