import onnx import queue
class Node:
    '''
    A vertex of the graph: a model input ("Input"), an ONNX operator
    (its op_type), or a tensor modelled as a node ("Edge").

    For operators, ``inputs`` / ``outputs`` are tensor names; for edges they
    are the names of the producing / consuming operators. All names are keys
    of ``Graph.map``.
    '''

    def __init__(self, name, op_type, inputs, outputs):
        self.name = name
        self.op_type = op_type
        self.inputs = inputs
        self.outputs = outputs


class Graph:
    def __init__(self, model_path=None):
        '''
        Initialize the graph object
        :param model_path: Passed unchanged to ``onnx.load``. If None, an
            empty graph is created and ``build`` can be called later with a
            ModelProto.
        '''
        self.entry = []
        self.map = {}
        self.model_path = model_path
        if model_path is None:
            return
        model = onnx.load(model_path)
        self.build(model)

    def build(self, model):
        '''
        Build the graph from a ModelProto object
        :param model: A ModelProto object; only ``model.graph.input`` and
            ``model.graph.node`` are read.

        Operators and tensors share ``self.map``: each operator is keyed by its
        node name, each tensor ("edge") by its tensor name.
        NOTE(review): operators with an empty or duplicate name overwrite each
        other (and any tensor of the same name) in ``self.map``.
        '''
        self.entry = [x.name for x in model.graph.input]
        # Start from a clean map so a repeated build() keeps no stale entries.
        self.map = {}
        for x in model.graph.input:
            # NOTE(review): outputs starts with the input's own name (a
            # self-reference); kept because downstream code may rely on it.
            self.map[x.name] = Node(x.name, "Input", inputs=[], outputs=[x.name])

        for node in model.graph.node:
            # ONNX uses "" for an omitted optional input/output; skip those so
            # unrelated nodes are not joined through one phantom "" edge.
            node_inputs = [x for x in node.input if x]
            node_outputs = [x for x in node.output if x]
            # Store independent copies: the appends below must never touch the
            # model's protobuf containers or the lists being iterated here
            # (happens when a tensor shares its name with this node).
            self.map[node.name] = Node(
                node.name, node.op_type, list(node_inputs), list(node_outputs)
            )

            # Model every tensor as its own node: its inputs are the operators
            # producing it, its outputs are the operators consuming it.
            for x in node_inputs:
                if x not in self.map:
                    # Tensor not in the graph yet: add it, with the current
                    # node as its consumer.
                    self.map[x] = Node(x, "Edge", inputs=[], outputs=[node.name])
                else:
                    self.map[x].outputs.append(node.name)

            for x in node_outputs:
                if x not in self.map:
                    # Tensor not in the graph yet: add it, with the current
                    # node as its producer.
                    self.map[x] = Node(x, "Edge", inputs=[node.name], outputs=[])
                else:
                    self.map[x].inputs.append(node.name)
def split(self, name, upstream_model_path = None, downstream_model_path = None): ''' Split the graph from a node :param node: A node name ''' if upstream_model_path is None: upstream_model_path = "upstream.onnx" if downstream_model_path is None: downstream_model_path = "downstream.onnx" if name not in self.map: raise ValueError("node not in graph")