我们一起来从零开始设计和实现一个 Python 下的 DAG(有向无环图),并结合 GitHub 上常见的代码模式进行优化。
第一步:理解 DAG 的基本概念和需求
首先,我们需要明确 DAG 的核心概念:
- 节点(Node): 代表任务或者操作。
- 有向边(Directed Edge): 表示节点之间的依赖关系,从一个节点指向另一个节点,意味着前者必须在后者之前完成。
- 无环(Acyclic): 图中不存在从某个节点出发,经过一系列边最终回到该节点自身的路径。这是 DAG 的关键特性。
我们的目标是实现一个 Python 类,能够:
- 添加节点: 允许用户向 DAG 中添加任务节点。
- 添加边: 允许用户定义节点之间的依赖关系。
- 执行 DAG: 按照依赖关系执行节点代表的任务。
- 检测环: 在添加边时或执行前检测是否存在环。
第二步:初步设计 - 核心数据结构
在 Python 中,表示图最常用的方式是使用邻接表。对于 DAG,我们可以使用字典来实现邻接表,其中:
- 键(Key): 代表一个节点。
- 值(Value): 是一个列表,包含该节点的所有后继节点(依赖于该节点的节点)。
同时,为了方便反向查找依赖关系,我们也可以维护一个反向邻接表:
- 键(Key): 代表一个节点。
- 值(Value): 是一个列表,包含所有指向该节点的节点(该节点依赖的节点)。
此外,为了存储节点代表的任务(函数或其他可执行对象),我们可以使用另一个字典:
- 键(Key): 代表一个节点。
- 值(Value): 代表与该节点关联的任务。
初步代码框架:
class DAG:
def __init__(self):
self._graph = {} # 邻接表:节点 -> [后继节点]
self._reverse_graph = {} # 反向邻接表:节点 -> [前驱节点]
self._tasks = {} # 节点 -> 任务
def add_node(self, node, task=None):
if node not in self._graph:
self._graph[node] = []
self._reverse_graph[node] = []
self._tasks[node] = task
def add_edge(self, from_node, to_node):
if from_node not in self._graph or to_node not in self._graph:
raise ValueError("节点不存在")
if to_node not in self._graph[from_node]:
self._graph[from_node].append(to_node)
self._reverse_graph[to_node].append(from_node)
# 思考:这里是否需要进行环检测?
第三步:完善添加边和环检测
在 add_edge 方法中,我们需要考虑环检测。一个常见的环检测方法是使用深度优先搜索(DFS)。
环检测思路:
- 维护三个集合:
visited: 已经访问过的节点。visiting: 当前正在访问的节点。
- 从每个节点开始进行 DFS。
- 在 DFS 过程中,如果遇到一个节点已经在
visiting集合中,则说明存在环。 - 当一个节点的所有后继节点都访问完毕后,将其从
visiting集合中移除,并加入visited集合。
改进后的 add_edge 方法:
class DAG:
# ... (之前的代码)
def add_edge(self, from_node, to_node):
if from_node not in self._graph or to_node not in self._graph:
raise ValueError("节点不存在")
if to_node not in self._graph[from_node]:
self._graph[from_node].append(to_node)
self._reverse_graph[to_node].append(from_node)
if self._has_cycle():
# 如果添加边导致环,则撤销操作
self._graph[from_node].remove(to_node)
self._reverse_graph[to_node].remove(from_node)
raise ValueError("添加边会导致环")
def _has_cycle(self):
visited = set()
visiting = set()
def _dfs(node):
visiting.add(node)
for neighbor in self._graph.get(node, []):
if neighbor in visiting:
return True
if neighbor not in visited:
if _dfs(neighbor):
return True
visiting.remove(node)
visited.add(node)
return False
for node in self._graph:
if node not in visited:
if _dfs(node):
return True
return False
思考与优化 1:
- 环检测的时机: 我们选择在每次添加边之后进行环检测,这可以尽早发现问题。另一种策略是在执行 DAG 前进行一次性检测。选择哪种方式取决于对性能和错误反馈的需求。频繁检测会增加开销,但能提供更即时的错误信息。
- 环检测算法: DFS 是一种常见的环检测方法,但对于大型图,可能需要考虑更高效的算法。
第四步:实现 DAG 的执行
执行 DAG 的核心是按照依赖关系排序节点,这可以通过拓扑排序算法实现。
拓扑排序思路:
- 计算每个节点的入度(指向该节点的边的数量)。
- 将所有入度为 0 的节点放入一个队列。
- 当队列不为空时:
- 从队列中取出一个节点。
- 执行该节点对应的任务。
- 将该节点的所有后继节点的入度减 1。
- 如果某个后继节点的入度变为 0,则将其加入队列。
- 如果所有节点都被处理,则执行成功。否则,图中存在环(这应该在添加边时就被检测出来)。
实现 execute 方法:
from collections import deque
class DAG:
# ... (之前的代码)
def execute(self):
in_degree = {node: len(self._reverse_graph[node]) for node in self._graph}
queue = deque([node for node in self._graph if in_degree[node] == 0])
executed_nodes = []
while queue:
node = queue.popleft()
print(f"执行节点: {node}")
task = self._tasks.get(node)
if task:
task() # 执行任务
executed_nodes.append(node)
for neighbor in self._graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
if len(executed_nodes) != len(self._graph):
raise RuntimeError("图中存在环,无法完成拓扑排序") # 理论上不会发生,因为添加边时已检测
思考与优化 2:
- 任务执行: 目前的任务执行是简单的函数调用。在实际应用中,任务可能需要传递参数、处理返回值、进行错误处理等。
- 并行执行: 对于相互独立的节点,可以并行执行以提高效率。可以使用
threading或asyncio模块来实现。 - 执行顺序: 拓扑排序保证了依赖关系的正确性,但对于没有依赖关系的节点,执行顺序可能不确定。如果需要特定的执行顺序,可以进行额外的排序或优先级控制。
第五步:添加更灵活的任务定义和执行
目前,我们假设任务是简单的无参函数。为了更灵活地处理各种任务,我们可以允许用户在添加节点时传递任意可调用对象,并允许在执行时传递参数。
改进后的 add_node 和 execute 方法:
class DAG:
# ... (之前的代码)
def add_node(self, node, task=None, *args, **kwargs):
if node not in self._graph:
self._graph[node] = []
self._reverse_graph[node] = []
self._tasks[node] = (task, args, kwargs) # 存储任务和参数
def execute(self):
in_degree = {node: len(self._reverse_graph[node]) for node in self._graph}
queue = deque([node for node in self._graph if in_degree[node] == 0])
executed_nodes = {} # 存储执行结果
while queue:
node = queue.popleft()
task_info = self._tasks.get(node)
if task_info:
task, args, kwargs = task_info
print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}")
try:
result = task(*args, **kwargs) # 执行任务并获取结果
executed_nodes[node] = result
except Exception as e:
print(f"节点 {node} 执行失败: {e}")
raise # 可以选择抛出异常或继续执行
for neighbor in self._graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
if len(executed_nodes) != len(self._graph):
raise RuntimeError("图中存在环或部分节点未执行")
return executed_nodes # 返回执行结果
思考与优化 3:
- 任务参数传递: 允许在添加节点时传递参数,使得任务可以接收特定的输入。
- 任务执行结果: 存储每个节点的执行结果,方便后续节点使用。
- 错误处理: 在任务执行过程中添加了
try-except块,可以捕获异常并进行处理。可以根据需求选择抛出异常、记录日志或跳过该节点。 - 依赖注入: 如果任务之间需要传递数据,可以通过执行结果来实现简单的依赖注入。例如,一个节点的输出可以作为另一个节点的输入。
第六步:借鉴 GitHub 常见代码模式进行优化
在 GitHub 上,常见的代码模式可以帮助我们提高代码的可读性、可维护性和性能。
- 使用装饰器: 可以使用装饰器来简化任务的添加和定义。
- 上下文管理器: 可以使用上下文管理器来管理资源的分配和释放。
- 生成器: 可以使用生成器来处理大型数据集或异步操作。
- 类型提示: 使用类型提示可以提高代码的可读性和可维护性,并帮助静态类型检查工具发现错误。
- 单元测试: 编写单元测试来验证 DAG 的各个功能是否正常工作。
示例:使用装饰器简化任务添加
class DAG:
# ... (之前的代码)
def task(self, node, *args, **kwargs):
def decorator(func):
self.add_node(node, func, *args, **kwargs)
return func
return decorator
# 使用装饰器添加任务
dag = DAG()
@dag.task("task_a")
def task_a():
print("执行任务 A")
return "result_a"
@dag.task("task_b", input_value=10)
def task_b(input_value):
print(f"执行任务 B,输入: {input_value}")
return input_value * 2
dag.add_edge("task_a", "task_b")
dag.execute()
思考与优化 4:
- 模块化设计: 将 DAG 的不同功能模块化,例如节点管理、边管理、执行引擎、环检测等,提高代码的可维护性。
- 配置化: 允许通过配置文件或外部数据源来定义 DAG 的结构和任务,提高灵活性。
- 可视化: 提供可视化 DAG 结构的功能,方便用户理解和调试。可以使用
graphviz等库。 - 异步执行: 使用
asyncio或concurrent.futures实现异步并行执行,提高性能。
第七步:进一步的思考和扩展
- 错误处理策略: 更精细的错误处理,例如重试机制、回滚操作、依赖失败处理等。
- 数据传递和共享: 更复杂的数据传递机制,例如使用消息队列或共享内存。
- 状态管理: 跟踪 DAG 的执行状态,例如节点的状态(等待、运行中、已完成、失败)。
- 监控和日志: 集成监控和日志功能,方便观察 DAG 的运行情况。
- 与其他工具集成: 例如与工作流引擎(如 Airflow、Luigi)集成。
总结与最终代码(包含一些优化):
from collections import deque
from functools import wraps
class DAG:
def __init__(self):
self._graph = {}
self._reverse_graph = {}
self._tasks = {}
def add_node(self, node, task=None, *args, **kwargs):
if node not in self._graph:
self._graph[node] = []
self._reverse_graph[node] = []
self._tasks[node] = (task, args, kwargs)
def add_edge(self, from_node, to_node):
if from_node not in self._graph or to_node not in self._graph:
raise ValueError("节点不存在")
if to_node not in self._graph[from_node]:
self._graph[from_node].append(to_node)
self._reverse_graph[to_node].append(from_node)
if self._has_cycle():
self._graph[from_node].remove(to_node)
self._reverse_graph[to_node].remove(from_node)
raise ValueError("添加边会导致环")
def _has_cycle(self):
visited = set()
visiting = set()
def _dfs(node):
visiting.add(node)
for neighbor in self._graph.get(node, []):
if neighbor in visiting:
return True
if neighbor not in visited:
if _dfs(neighbor):
return True
visiting.remove(node)
visited.add(node)
return False
for node in self._graph:
if node not in visited:
if _dfs(node):
return True
return False
def task(self, node, *args, **kwargs):
def decorator(func):
self.add_node(node, func, *args, **kwargs)
@wraps(func)
def wrapper(*_args, **_kwargs):
return func(*_args, **_kwargs)
return wrapper
return decorator
def execute(self):
in_degree = {node: len(self._reverse_graph[node]) for node in self._graph}
queue = deque([node for node in self._graph if in_degree[node] == 0])
executed_nodes = {}
while queue:
node = queue.popleft()
task_info = self._tasks.get(node)
if task_info:
task, args, kwargs = task_info
print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}")
try:
result = task(*args, **kwargs)
executed_nodes[node] = result
except Exception as e:
print(f"节点 {node} 执行失败: {e}")
raise
for neighbor in self._graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
if len(executed_nodes) != len(self._graph):
raise RuntimeError("图中存在环或部分节点未执行")
return executed_nodes
# 示例用法
dag = DAG()
@dag.task("start")
def start_task():
print("开始任务")
return 10
@dag.task("process", multiplier=2)
def process_task(value, multiplier):
print(f"处理任务,值: {value},乘数: {multiplier}")
return value * multiplier
@dag.task("end")
def end_task(value):
print(f"结束任务,最终值: {value}")
dag.add_edge("start", "process")
dag.add_edge("process", "end")
results = dag.execute()
print("执行结果:", results)
通过这个过程,我们从零开始设计并实现了一个 Python 下的 DAG,并逐步进行了优化和扩展