切分 (splitting an ONNX model into upstream and downstream sub-models)

62 阅读1分钟

import onnx import queue

class Node:
    """A vertex of :class:`Graph`: an ONNX operator, a graph input, or a tensor edge.

    ``inputs`` and ``outputs`` hold names of neighbouring vertices (keys of
    ``Graph.map``), so the graph can be walked purely by name.
    """

    def __init__(self, name, op_type, inputs, outputs):
        self.name = name
        self.op_type = op_type
        # Copy into plain lists: callers may pass protobuf repeated fields, and
        # Graph.build appends to these lists later, which must never mutate
        # the caller's ModelProto.
        self.inputs = list(inputs)
        self.outputs = list(outputs)

    def __repr__(self):
        return (
            f"{type(self).__name__}(name={self.name!r}, op_type={self.op_type!r}, "
            f"inputs={self.inputs!r}, outputs={self.outputs!r})"
        )


class Graph:
    """Name-indexed bipartite view of an ONNX graph.

    Operators and tensors are both stored as :class:`Node` objects in
    ``self.map``. An operator's ``inputs``/``outputs`` are tensor names. A
    tensor vertex (op_type ``"Input"`` for graph inputs, ``"Edge"`` otherwise)
    lists its producing operators in ``inputs`` and its consuming operators in
    ``outputs``.
    """

    def __init__(self, model_path=None):
        """Initialize the graph, optionally loading a model from disk.

        :param model_path: Path to an ONNX model file, or ``None`` to create an
            empty graph that can be populated later with :meth:`build`.
        """
        self.entry = []  # names of the model's graph inputs
        self.map = {}  # vertex name -> Node (operators and tensors)
        self.model_path = model_path
        if model_path is None:
            return
        model = onnx.load(model_path)
        self.build(model)

    def build(self, model):
        """Populate the graph from a ModelProto.

        :param model: An ``onnx.ModelProto``; only ``model.graph.input``
            (``.name``) and ``model.graph.node`` (``.name``, ``.op_type``,
            ``.input``, ``.output``) are read.

        NOTE(review): operator names and tensor names share ``self.map``, but
        ONNX keeps them in separate namespaces. A model where a node is named
        like one of its tensors, or that has several unnamed nodes
        (``name == ""``), will have vertices overwritten or merged. Confirm the
        models fed to this class use unique, non-empty, non-colliding names.
        """
        self.entry = [x.name for x in model.graph.input]
        for x in model.graph.input:
            # A graph input is a tensor with no producer. Its consumers are
            # appended below as operators that read it are visited.
            self.map[x.name] = Node(x.name, "Input", inputs=[], outputs=[])

        for node in model.graph.node:
            # ONNX encodes an omitted optional argument as "". That is not a
            # real tensor and must not become a vertex that would connect
            # otherwise unrelated operators.
            ins = [x for x in node.input if x]
            outs = [x for x in node.output if x]
            self.map[node.name] = Node(node.name, node.op_type, ins, outs)

            # Each tensor is a vertex between its producer and its consumers.
            for x in ins:
                if x not in self.map:
                    # First sighting of this tensor: current node consumes it.
                    self.map[x] = Node(x, "Edge", inputs=[], outputs=[node.name])
                else:
                    self.map[x].outputs.append(node.name)

            for x in outs:
                if x not in self.map:
                    self.map[x] = Node(x, "Edge", inputs=[node.name], outputs=[])
                else:
                    self.map[x].inputs.append(node.name)

    def split(self, name, upstream_model_path = None, downstream_model_path = None):         '''         Split the graph from a node         :param node: A node name         '''         if upstream_model_path is None:             upstream_model_path = "upstream.onnx"         if downstream_model_path is None:             downstream_model_path = "downstream.onnx"         if name not in self.map:             raise ValueError("node not in graph")