从零开始设计和实现一个 Python 下的 DAG(有向无环图)

700 阅读9分钟

我们一起来从零开始设计和实现一个 Python 下的 DAG(有向无环图),并结合 GitHub 上常见的代码模式进行优化。

第一步:理解 DAG 的基本概念和需求

首先,我们需要明确 DAG 的核心概念:

  • 节点(Node): 代表任务或者操作。
  • 有向边(Directed Edge): 表示节点之间的依赖关系,从一个节点指向另一个节点,意味着前者必须在后者之前完成。
  • 无环(Acyclic): 图中不存在从某个节点出发,经过一系列边最终回到该节点自身的路径。这是 DAG 的关键特性。

我们的目标是实现一个 Python 类,能够:

  1. 添加节点: 允许用户向 DAG 中添加任务节点。
  2. 添加边: 允许用户定义节点之间的依赖关系。
  3. 执行 DAG: 按照依赖关系执行节点代表的任务。
  4. 检测环: 在添加边时或执行前检测是否存在环。

第二步:初步设计 - 核心数据结构

在 Python 中,表示图最常用的方式是使用邻接表。对于 DAG,我们可以使用字典来实现邻接表,其中:

  • 键(Key): 代表一个节点。
  • 值(Value): 是一个列表,包含该节点的所有后继节点(依赖于该节点的节点)。

同时,为了方便反向查找依赖关系,我们也可以维护一个反向邻接表:

  • 键(Key): 代表一个节点。
  • 值(Value): 是一个列表,包含所有指向该节点的节点(该节点依赖的节点)。

此外,为了存储节点代表的任务(函数或其他可执行对象),我们可以使用另一个字典:

  • 键(Key): 代表一个节点。
  • 值(Value): 代表与该节点关联的任务。

初步代码框架:

class DAG:
    def __init__(self):
        self._graph = {}  # 邻接表:节点 -> [后继节点]
        self._reverse_graph = {}  # 反向邻接表:节点 -> [前驱节点]
        self._tasks = {}  # 节点 -> 任务

    def add_node(self, node, task=None):
        if node not in self._graph:
            self._graph[node] = []
            self._reverse_graph[node] = []
            self._tasks[node] = task

    def add_edge(self, from_node, to_node):
        if from_node not in self._graph or to_node not in self._graph:
            raise ValueError("节点不存在")
        if to_node not in self._graph[from_node]:
            self._graph[from_node].append(to_node)
            self._reverse_graph[to_node].append(from_node)
        # 思考:这里是否需要进行环检测?

第三步:完善添加边和环检测

add_edge 方法中,我们需要考虑环检测。一个常见的环检测方法是使用深度优先搜索(DFS)。

环检测思路:

  1. 维护三个集合:
    • visited: 已经访问过的节点。
    • visiting: 当前正在访问的节点。
  2. 从每个节点开始进行 DFS。
  3. 在 DFS 过程中,如果遇到一个节点已经在 visiting 集合中,则说明存在环。
  4. 当一个节点的所有后继节点都访问完毕后,将其从 visiting 集合中移除,并加入 visited 集合。

改进后的 add_edge 方法:

class DAG:
    # ... (之前的代码)

    def add_edge(self, from_node, to_node):
        if from_node not in self._graph or to_node not in self._graph:
            raise ValueError("节点不存在")
        if to_node not in self._graph[from_node]:
            self._graph[from_node].append(to_node)
            self._reverse_graph[to_node].append(from_node)
            if self._has_cycle():
                # 如果添加边导致环,则撤销操作
                self._graph[from_node].remove(to_node)
                self._reverse_graph[to_node].remove(from_node)
                raise ValueError("添加边会导致环")

    def _has_cycle(self):
        visited = set()
        visiting = set()

        def _dfs(node):
            visiting.add(node)
            for neighbor in self._graph.get(node, []):
                if neighbor in visiting:
                    return True
                if neighbor not in visited:
                    if _dfs(neighbor):
                        return True
            visiting.remove(node)
            visited.add(node)
            return False

        for node in self._graph:
            if node not in visited:
                if _dfs(node):
                    return True
        return False

思考与优化 1:

  • 环检测的时机: 我们选择在每次添加边之后进行环检测,这可以尽早发现问题。另一种策略是在执行 DAG 前进行一次性检测。选择哪种方式取决于对性能和错误反馈的需求。频繁检测会增加开销,但能提供更即时的错误信息。
  • 环检测算法: DFS 是一种常见的环检测方法,但对于大型图,可能需要考虑更高效的算法。

第四步:实现 DAG 的执行

执行 DAG 的核心是按照依赖关系排序节点,这可以通过拓扑排序算法实现。

拓扑排序思路:

  1. 计算每个节点的入度(指向该节点的边的数量)。
  2. 将所有入度为 0 的节点放入一个队列。
  3. 当队列不为空时:
    • 从队列中取出一个节点。
    • 执行该节点对应的任务。
    • 将该节点的所有后继节点的入度减 1。
    • 如果某个后继节点的入度变为 0,则将其加入队列。
  4. 如果所有节点都被处理,则执行成功。否则,图中存在环(这应该在添加边时就被检测出来)。

实现 execute 方法:

from collections import deque

class DAG:
    # ... (之前的代码)

    def execute(self):
        in_degree = {node: len(self._reverse_graph[node]) for node in self._graph}
        queue = deque([node for node in self._graph if in_degree[node] == 0])
        executed_nodes = []

        while queue:
            node = queue.popleft()
            print(f"执行节点: {node}")
            task = self._tasks.get(node)
            if task:
                task()  # 执行任务
            executed_nodes.append(node)

            for neighbor in self._graph.get(node, []):
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(executed_nodes) != len(self._graph):
            raise RuntimeError("图中存在环,无法完成拓扑排序") # 理论上不会发生,因为添加边时已检测

思考与优化 2:

  • 任务执行: 目前的任务执行是简单的函数调用。在实际应用中,任务可能需要传递参数、处理返回值、进行错误处理等。
  • 并行执行: 对于相互独立的节点,可以并行执行以提高效率。可以使用 threadingasyncio 模块来实现。
  • 执行顺序: 拓扑排序保证了依赖关系的正确性,但对于没有依赖关系的节点,执行顺序可能不确定。如果需要特定的执行顺序,可以进行额外的排序或优先级控制。

第五步:添加更灵活的任务定义和执行

目前,我们假设任务是简单的无参函数。为了更灵活地处理各种任务,我们可以允许用户在添加节点时传递任意可调用对象,并允许在执行时传递参数。

改进后的 add_nodeexecute 方法:

class DAG:
    # ... (之前的代码)

    def add_node(self, node, task=None, *args, **kwargs):
        if node not in self._graph:
            self._graph[node] = []
            self._reverse_graph[node] = []
            self._tasks[node] = (task, args, kwargs)  # 存储任务和参数

    def execute(self):
        in_degree = {node: len(self._reverse_graph[node]) for node in self._graph}
        queue = deque([node for node in self._graph if in_degree[node] == 0])
        executed_nodes = {} # 存储执行结果

        while queue:
            node = queue.popleft()
            task_info = self._tasks.get(node)
            if task_info:
                task, args, kwargs = task_info
                print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}")
                try:
                    result = task(*args, **kwargs) # 执行任务并获取结果
                    executed_nodes[node] = result
                except Exception as e:
                    print(f"节点 {node} 执行失败: {e}")
                    raise  # 可以选择抛出异常或继续执行

            for neighbor in self._graph.get(node, []):
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(executed_nodes) != len(self._graph):
            raise RuntimeError("图中存在环或部分节点未执行")

        return executed_nodes # 返回执行结果

思考与优化 3:

  • 任务参数传递: 允许在添加节点时传递参数,使得任务可以接收特定的输入。
  • 任务执行结果: 存储每个节点的执行结果,方便后续节点使用。
  • 错误处理: 在任务执行过程中添加了 try-except 块,可以捕获异常并进行处理。可以根据需求选择抛出异常、记录日志或跳过该节点。
  • 依赖注入: 如果任务之间需要传递数据,可以通过执行结果来实现简单的依赖注入。例如,一个节点的输出可以作为另一个节点的输入。

第六步:借鉴 GitHub 常见代码模式进行优化

在 GitHub 上,常见的代码模式可以帮助我们提高代码的可读性、可维护性和性能。

  • 使用装饰器: 可以使用装饰器来简化任务的添加和定义。
  • 上下文管理器: 可以使用上下文管理器来管理资源的分配和释放。
  • 生成器: 可以使用生成器来处理大型数据集或异步操作。
  • 类型提示: 使用类型提示可以提高代码的可读性和可维护性,并帮助静态类型检查工具发现错误。
  • 单元测试: 编写单元测试来验证 DAG 的各个功能是否正常工作。

示例:使用装饰器简化任务添加

class DAG:
    # ... (之前的代码)

    def task(self, node, *args, **kwargs):
        def decorator(func):
            self.add_node(node, func, *args, **kwargs)
            return func
        return decorator

# 使用装饰器添加任务
dag = DAG()

@dag.task("task_a")
def task_a():
    print("执行任务 A")
    return "result_a"

@dag.task("task_b", input_value=10)
def task_b(input_value):
    print(f"执行任务 B,输入: {input_value}")
    return input_value * 2

dag.add_edge("task_a", "task_b")
dag.execute()

思考与优化 4:

  • 模块化设计: 将 DAG 的不同功能模块化,例如节点管理、边管理、执行引擎、环检测等,提高代码的可维护性。
  • 配置化: 允许通过配置文件或外部数据源来定义 DAG 的结构和任务,提高灵活性。
  • 可视化: 提供可视化 DAG 结构的功能,方便用户理解和调试。可以使用 graphviz 等库。
  • 异步执行: 使用 asyncioconcurrent.futures 实现异步并行执行,提高性能。

第七步:进一步的思考和扩展

  • 错误处理策略: 更精细的错误处理,例如重试机制、回滚操作、依赖失败处理等。
  • 数据传递和共享: 更复杂的数据传递机制,例如使用消息队列或共享内存。
  • 状态管理: 跟踪 DAG 的执行状态,例如节点的状态(等待、运行中、已完成、失败)。
  • 监控和日志: 集成监控和日志功能,方便观察 DAG 的运行情况。
  • 与其他工具集成: 例如与工作流引擎(如 Airflow、Luigi)集成。

总结与最终代码(包含一些优化):

from collections import deque
from functools import wraps

class DAG:
    def __init__(self):
        self._graph = {}
        self._reverse_graph = {}
        self._tasks = {}

    def add_node(self, node, task=None, *args, **kwargs):
        if node not in self._graph:
            self._graph[node] = []
            self._reverse_graph[node] = []
            self._tasks[node] = (task, args, kwargs)

    def add_edge(self, from_node, to_node):
        if from_node not in self._graph or to_node not in self._graph:
            raise ValueError("节点不存在")
        if to_node not in self._graph[from_node]:
            self._graph[from_node].append(to_node)
            self._reverse_graph[to_node].append(from_node)
            if self._has_cycle():
                self._graph[from_node].remove(to_node)
                self._reverse_graph[to_node].remove(from_node)
                raise ValueError("添加边会导致环")

    def _has_cycle(self):
        visited = set()
        visiting = set()

        def _dfs(node):
            visiting.add(node)
            for neighbor in self._graph.get(node, []):
                if neighbor in visiting:
                    return True
                if neighbor not in visited:
                    if _dfs(neighbor):
                        return True
            visiting.remove(node)
            visited.add(node)
            return False

        for node in self._graph:
            if node not in visited:
                if _dfs(node):
                    return True
        return False

    def task(self, node, *args, **kwargs):
        def decorator(func):
            self.add_node(node, func, *args, **kwargs)
            @wraps(func)
            def wrapper(*_args, **_kwargs):
                return func(*_args, **_kwargs)
            return wrapper
        return decorator

    def execute(self):
        in_degree = {node: len(self._reverse_graph[node]) for node in self._graph}
        queue = deque([node for node in self._graph if in_degree[node] == 0])
        executed_nodes = {}

        while queue:
            node = queue.popleft()
            task_info = self._tasks.get(node)
            if task_info:
                task, args, kwargs = task_info
                print(f"执行节点: {node}, 任务: {task.__name__ if callable(task) else task}")
                try:
                    result = task(*args, **kwargs)
                    executed_nodes[node] = result
                except Exception as e:
                    print(f"节点 {node} 执行失败: {e}")
                    raise

            for neighbor in self._graph.get(node, []):
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(executed_nodes) != len(self._graph):
            raise RuntimeError("图中存在环或部分节点未执行")

        return executed_nodes

# 示例用法
dag = DAG()

@dag.task("start")
def start_task():
    print("开始任务")
    return 10

@dag.task("process", multiplier=2)
def process_task(value, multiplier):
    print(f"处理任务,值: {value},乘数: {multiplier}")
    return value * multiplier

@dag.task("end")
def end_task(value):
    print(f"结束任务,最终值: {value}")

dag.add_edge("start", "process")
dag.add_edge("process", "end")

results = dag.execute()
print("执行结果:", results)

通过这个过程,我们从零开始设计并实现了一个 Python 下的 DAG,并逐步进行了优化和扩展