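# 在 Python 中输出 ONNX 模型的常用结构信息
# (Print the common structural information of an ONNX model in Python.)
#
# 160 次阅读 · 阅读时长 1 分钟 (160 reads · 1-minute read)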
def print_onnx_info(onnx_model):
    """Print a human-readable summary of an ONNX model to stdout.

    Sections are printed in this order: model properties, graph inputs,
    graph outputs, nodes (with their input/output tensor names) and
    initializers.

    Args:
        onnx_model: A loaded ``onnx.ModelProto`` (e.g. the result of
            ``onnx.load``).
    """
    print(" --- ONNX Model Information ---")

    # Model-level metadata stored directly on the ModelProto.
    print("[MODEL PROPERTIES]")
    print(f"Ir Version: {onnx_model.ir_version}")
    print(f"Producer Name: {onnx_model.producer_name}")
    print(f"Producer Version: {onnx_model.producer_version}")
    print(f"Domain: {onnx_model.domain}")
    print(f"Model Version: {onnx_model.model_version}")
    print(f"Doc String: {onnx_model.doc_string}")

    graph = onnx_model.graph

    print("\n[INPUTS]")
    _print_value_infos(graph.input)

    print("\n[OUTPUTS]")
    _print_value_infos(graph.output)

    print("\n[NODES]")
    for node in graph.node:
        print(f"Node name: <{node.name}>, Op Type: {node.op_type}")
        # list() gives a plain-list repr regardless of the protobuf backend.
        print(f"  Inputs: {list(node.input)}")
        print(f"  Outputs: {list(node.output)}")
        print()

    print("\n[INITIALIZERS]")
    for initializer in graph.initializer:
        print(
            f"Initializer Name: <{initializer.name}>   "
            f"Shape: {list(initializer.dims)}   "
            f"Dtype: {_dtype_name(initializer.data_type)}"
        )


def _print_value_infos(value_infos):
    """Print name, shape and dtype for each ValueInfoProto (graph input/output)."""
    for value_info in value_infos:
        type_kind = value_info.type.WhichOneof("value")
        if type_kind != "tensor_type":
            # Sequence/map/optional/sparse values carry no tensor_type, so
            # report their type kind instead of a bogus shape/dtype.
            print(f"Tensor Name: <{value_info.name}>   Type: {type_kind}")
            continue
        tensor_type = value_info.type.tensor_type
        if tensor_type.HasField("shape"):
            shape = [_format_dim(dim) for dim in tensor_type.shape.dim]
        else:
            # No shape at all means the rank is unknown; do not print [] (scalar).
            shape = "unknown"
        print(
            f"Tensor Name: <{value_info.name}>   Shape: {shape}   "
            f"Dtype: {_dtype_name(tensor_type.elem_type)}"
        )


def _format_dim(dim):
    """Return a shape dimension as an int, its symbolic name, or 'dynamic'."""
    kind = dim.WhichOneof("value")
    if kind == "dim_value" and dim.dim_value >= 0:
        # An explicitly set value is static, even when it is 0 (empty axis).
        return dim.dim_value
    if kind == "dim_param" and dim.dim_param:
        # Symbolic dimension, e.g. "batch_size".
        return dim.dim_param
    # Unset dimension (or a negative placeholder written by some exporters).
    return "dynamic"


def _dtype_name(elem_type):
    """Return the NumPy dtype for an ONNX elem_type code, else its enum name.

    Never raises. UNDEFINED (0) or element types without a NumPy equivalent
    therefore cannot abort the whole report.
    """
    # tensor_dtype_to_np_dtype (onnx >= 1.13) replaces the deprecated
    # onnx.mapping.TENSOR_TYPE_TO_NP_TYPE table.
    to_np_dtype = getattr(onnx.helper, "tensor_dtype_to_np_dtype", None)
    if to_np_dtype is not None:
        try:
            return to_np_dtype(elem_type)
        except (KeyError, ValueError):
            pass  # No NumPy equivalent: fall back to the enum name below.
    try:
        return onnx.TensorProto.DataType.Name(elem_type)
    except ValueError:
        return f"<unknown elem_type {elem_type}>"

# Example usage: python <this_script>.py [path/to/model.onnx]
import sys

import onnx

if __name__ == "__main__":
    # Guarded so that importing this file does not try to load a model.
    # Defaults to ./model.onnx when no path is given on the command line.
    model_path = sys.argv[1] if len(sys.argv) > 1 else "model.onnx"
    onnx_model = onnx.load(model_path)
    print_onnx_info(onnx_model)