# Printing the common structure of TensorRT models in Python.
import onnx
import onnx.helper as helper
import onnx.numpy_helper as numpy_helper
import numpy as np
import tensorrt as trt

def create_simple_onnx_model(onnx_file_path):
    """Build a tiny Conv -> Relu -> Flatten -> Gemm ONNX model and save it.

    The graph maps a (1, 3, 32, 32) float input to a (1, 10) float output.
    Weights are random, so the model is only useful for structural testing.
    """
    # Graph I/O declarations.
    inp = helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, [1, 3, 32, 32])
    out = helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [1, 10])

    # Conv: 10 filters of shape 3x3 over 3 channels; no padding, so the
    # spatial size shrinks from 32x32 to 30x30.
    conv_w = numpy_helper.from_array(
        np.random.randn(10, 3, 3, 3).astype(np.float32), name='conv_weight')
    nodes = [
        helper.make_node('Conv', inputs=['input', 'conv_weight'],
                         outputs=['conv_out'], kernel_shape=[3, 3]),
        helper.make_node('Relu', inputs=['conv_out'], outputs=['relu_out']),
        # Flatten to (1, 10*30*30) = (1, 9000) before the fully-connected layer.
        helper.make_node('Flatten', inputs=['relu_out'], outputs=['flatten_out'], axis=1),
    ]

    # Gemm (no bias): (1, 9000) x (9000, 10) -> (1, 10).
    fc_w = numpy_helper.from_array(
        np.random.randn(9000, 10).astype(np.float32), name='fc_weight')
    nodes.append(helper.make_node('Gemm', inputs=['flatten_out', 'fc_weight'],
                                  outputs=['output']))

    graph = helper.make_graph(nodes, 'simple_model', [inp], [out], [conv_w, fc_w])
    model = helper.make_model(graph, producer_name='simple_onnx_model')

    onnx.save(model, onnx_file_path)
    print(f"ONNX model saved to {onnx_file_path}")

def parse_onnx_to_trt(onnx_file_path):
    """Parse an ONNX file into a TensorRT network and print every layer's I/O.

    Returns the populated INetworkDefinition, or None when parsing fails
    (parser errors are printed before returning).
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # Explicit-batch network, as required by the ONNX parser.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, logger)

    with open(onnx_file_path, 'rb') as model:
        parsed_ok = parser.parse(model.read())
    if not parsed_ok:
        print('Failed to parse the ONNX file.')
        for err_idx in range(parser.num_errors):
            print(parser.get_error(err_idx))
        return None

    # Dump each layer together with its input and output tensor metadata.
    for layer_idx in range(network.num_layers):
        layer = network.get_layer(layer_idx)
        print(f"Layer {layer_idx}: {layer.name}, Type: {layer.type}, Precision: {layer.precision}")
        for j in range(layer.num_inputs):
            tensor = layer.get_input(j)
            print(f"  Input {j}: {tensor.name}, Shape: {tensor.shape}, Dtype: {tensor.dtype}")
        for j in range(layer.num_outputs):
            tensor = layer.get_output(j)
            print(f"  Output {j}: {tensor.name}, Shape: {tensor.shape}, Dtype: {tensor.dtype}")

    return network

def load_and_inspect_engine(engine_file_path):
    """Deserialize a TensorRT engine file and print its I/O tensor layout.

    Only metadata available on a deserialized ICudaEngine is reported;
    per-layer details are not recoverable from the serialized engine.
    """
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(TRT_LOGGER)

    try:
        with open(engine_file_path, 'rb') as f:
            engine = runtime.deserialize_cuda_engine(f.read())
    except Exception as e:
        print(f"Failed to load the engine: {str(e)}")
        return
    # deserialize_cuda_engine signals failure by returning None rather than
    # raising, so the try/except alone is not enough to detect a bad engine.
    if engine is None:
        print("Failed to load the engine: deserialization returned None.")
        return
    print("Engine loaded successfully.")

    print(" --- TensorRT Engine Information ---")
    # max_workspace_size exists only on older TensorRT releases; guard both.
    print(f"Max Workspace Size: {engine.max_workspace_size if hasattr(engine, 'max_workspace_size') else 'N/A'} bytes")
    print(f"Device Memory Size: {engine.device_memory_size if hasattr(engine, 'device_memory_size') else 'N/A'} bytes")

    # num_bindings was removed in TensorRT 10; num_io_tensors is the
    # replacement and has existed alongside the tensor-name API (used below)
    # since TensorRT 8.5. Fall back for older runtimes.
    num_tensors = getattr(engine, 'num_io_tensors', None)
    if num_tensors is None:
        num_tensors = engine.num_bindings

    # Classify every I/O tensor in a single pass, then print grouped so the
    # output format matches the previous two-loop version.
    inputs, outputs = [], []
    for i in range(num_tensors):
        name = engine.get_tensor_name(i)
        info = (name, engine.get_tensor_shape(name), engine.get_tensor_dtype(name))
        mode = engine.get_tensor_mode(name)
        if mode == trt.TensorIOMode.INPUT:
            inputs.append(info)
        elif mode == trt.TensorIOMode.OUTPUT:
            outputs.append(info)

    print("\n[INPUTS]")
    for name, shape, dtype in inputs:
        print(f"Tensor Name: <{name}>   Shape: {shape}   Dtype: {dtype}")

    print("\n[OUTPUTS]")
    for name, shape, dtype in outputs:
        print(f"Tensor Name: <{name}>   Shape: {shape}   Dtype: {dtype}")

    print("\n[LAYERS]")
    print("Layer information is not accessible directly from the deserialized engine.")

    print("\n[DETAILED LAYER INFORMATION]")
    print("Layer information cannot be accessed after engine deserialization.")

def main():
    """Run the example workflows: build + parse an ONNX model, then inspect an engine."""
    # Example usage 1: build a toy ONNX model and dump its TensorRT network layers.
    onnx_file_path = 'model.onnx'
    create_simple_onnx_model(onnx_file_path)
    parse_onnx_to_trt(onnx_file_path)

    # Example usage 2: inspect a prebuilt serialized engine file.
    engine_file_path = 'model.engine'
    load_and_inspect_engine(engine_file_path)


# Guard so importing this module does not trigger the examples.
if __name__ == "__main__":
    main()