TensorRT_sample_onnx_MNIST逐行中文解析

459 阅读14分钟

逐行中文解析

/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//!
//! sampleOnnxMNIST.cpp
//! This file contains the implementation of the ONNX MNIST sample. It creates the network using
//! the MNIST onnx model.
//! It can be run with the following command line:
//! Command: ./sample_onnx_mnist [-h or --help] [-d=/path/to/data/dir or --datadir=/path/to/data/dir]
//! [--useDLACore=<int>]
//!

// Define TRT entrypoints used in common code
#define DEFINE_TRT_ENTRYPOINTS 1 
#define DEFINE_TRT_LEGACY_PARSER_ENTRYPOINT 0
// DEFINE_TRT_ENTRYPOINTS宏用于定义TensorRT常用代码中的入口点,这些入口点通常用于构建网络、运行推理等操作。
// DEFINE_TRT_LEGACY_PARSER_ENTRYPOINT宏用于定义旧的解析器入口点,用于在TensorRT中解析旧版本的模型。

// 在C++中,使用尖括号<>包围的是标准库头文件,而使用双引号""包围的是用户自定义的头文件或非标准库头文件。尖括号用于包含标准库头文件,编译器会在标准库的目录中查找这些头文件;双引号用于包含用户自定义的头文件或非标准库头文件,编译器会在当前目录或指定的目录中查找这些头文件。
#include "argsParser.h" // 解析命令行参数,帮助配置TensorRT模型的参数。
#include "buffers.h" // 包含了用于管理输入和输出缓冲区的功能,帮助在TensorRT中处理数据。
#include "common.h" 
#include "logger.h"
#include "parserOnnxConfig.h" // 解析ONNX模型配置

#include "NvInfer.h" // TensorRT的主要头文件,定义了TensorRT的核心类和函数。
#include <cuda_runtime_api.h> // CUDA运行时API的头文件,用于与CUDA进行交互。

#include <cstdlib> //  C标准库中的头文件,提供了一些常用的函数,如内存分配和转换函数。
#include <fstream> // 文件流库的头文件,用于文件的输入输出操作
#include <iostream> // 输入输出流库的头文件,用于标准输入输出操作
#include <sstream> // 字符串流库的头文件,用于在内存中操作字符串
using namespace nvinfer1; // TensorRT的命名空间,包含了TensorRT的主要类和函数。
using samplesCommon::SampleUniquePtr; // 自定义的智能指针类,用于管理资源的所有权

const std::string gSampleName = "TensorRT.sample_onnx_mnist"; // 存储了TensorRT样本的名称

//! \brief  SampleOnnxMNIST类实现了ONNX MNIST示例
//!
//! \details 使用ONNX模型创建网络
//!
class SampleOnnxMNIST
{
public:
    /**
     * @brief 构造函数,初始化SampleOnnxMNIST对象
     * 
     * @param params ONNX示例参数
     它接受一个类型为samplesCommon::OnnxSampleParams的参数params,
     并使用成员初始化列表初始化了类的成员变量
     mParams为传入的参数params,
     mRuntime初始化为nullptr,
     mEngine初始化为nullptr。
     构造函数的主体部分为空,表示没有额外的构造逻辑需要执行。
     */
    /**
     * @brief 构造函数,初始化SampleOnnxMNIST对象
     * 
     * @param params ONNX示例参数
     * 该构造函数接受一个类型为samplesCommon::OnnxSampleParams的参数params,
     * 并使用成员初始化列表初始化了类的成员变量
     * mParams为传入的参数params,
     * mRuntime初始化为nullptr,
     * mEngine初始化为nullptr。
     * 构造函数的主体部分为空,表示没有额外的构造逻辑需要执行。
     */
    SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params)
        : mParams(params)
        , mRuntime(nullptr)
        , mEngine(nullptr)
    {
    }

    /**
     * @brief 构建网络引擎的函数
     */
    bool build();

    /**
     * @brief 运行此示例的TensorRT推理引擎
     */
    bool infer();

private:
    samplesCommon::OnnxSampleParams mParams; //!< 用于存储示例参数。

    nvinfer1::Dims mInputDims;  //!< 网络输入的维度。
    nvinfer1::Dims mOutputDims; //!< 网络输出的维度。
    int mNumber{0};             //!< 待分类的数字

    std::shared_ptr<nvinfer1::IRuntime> mRuntime;   //!< 用于反序列化引擎的TensorRT运行时
    std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< 用于运行网络的TensorRT引擎

    //!
    //! \brief 解析MNIST的ONNX模型并创建TensorRT网络
    //!
    //!
    //! \brief 构建网络并解析MNIST的ONNX模型
    //!
    //! \param builder 构建器对象
    //! \param network 网络定义对象
    //! \param config 构建器配置对象
    //! \param parser ONNX解析器对象
    //!
    bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
        SampleUniquePtr<nvinfer1::INetworkDefinition>& network, 
        SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
        SampleUniquePtr<nvonnxparser::IParser>& parser);
    //!
    //! 读取输入数据并将结果存储在BufferManager管理的缓冲区中
    //!
    bool processInput(const samplesCommon::BufferManager& buffers);
    //!
    //! \brief 对数字进行分类并验证结果
    //!
    bool verifyOutput(const samplesCommon::BufferManager& buffers);
};

//!
//! \brief 创建网络,配置builder并创建 network engine
//!
//! \details 此函数通过解析Onnx模型创建Onnx MNIST网络,并构建
//!          将用于运行MNIST的引擎(mEngine)
//!
//! \return 如果成功创建引擎则返回true,否则返回false
//!
bool SampleOnnxMNIST::build()
{
    // 创建一个独特指针的构建器对象,使用createInferBuilder函数创建TensorRT构建器,传入日志记录器
    // 这行代码的意思是使用SampleUniquePtr智能指针类创建了一个指向nvinfer1::INetworkDefinition类型对象的指针network,并通过调用builder->createNetworkV2(0)来创建一个TensorRT网络。SampleUniquePtr是一个自定义的智能指针类,用于管理资源的所有权。
    // 这样写代码的好处是使用智能指针可以自动管理资源的生命周期,避免内存泄漏和资源泄漏的问题。当network指针超出作用域时,智能指针会自动释放所指向的内存,确保资源的正确释放,提高代码的健壮性和可维护性。
    // builder->createNetworkV2(0)来创建一个TensorRT网络
    // I开头的是实现接口
    // 没有I开头的是抽象接口
    /* 
    在C++中,抽象类和实现类的区别在于:
    1. 抽象类(Abstract Class):
    抽象类是包含至少一个纯虚函数(纯虚函数通过在函数声明末尾添加 = 0 来声明)的类。
    不能直接实例化抽象类的对象,只能作为基类来派生出其他类。
    抽象类用于定义接口和行为,要求派生类实现其纯虚函数。
    如果派生类没有实现抽象类中的所有纯虚函数,那么派生类也会成为抽象类。
    2. 实现类(Concrete Class):
    实现类是指没有纯虚函数的类,可以被实例化为对象。
    实现类可以直接实例化对象,不需要派生其他类。
    实现类可以继承自抽象类,并实现抽象类中的纯虚函数,从而变成一个具体的类。
    总的来说,抽象类用于定义接口和规范行为,要求派生类实现特定的功能;而实现类是具体的类,可以直接实例化对象并提供具体的实现。
    */
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }
    
    // 创建一个独特指针的网络定义对象,使用构建器创建网络
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0));
    if (!network)
    {
        return false;
    }

    // 创建一个独特指针的构建器配置对象,使用构建器创建配置
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    // 创建一个独特指针的ONNX解析器对象,使用网络和日志记录器创建解析器
    auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }

    // 构建网络并返回构建结果
    auto constructed = constructNetwork(builder, network, config, parser);
    if (!constructed)
    {
        return false;
    }
    // 用于构建器进行性能分析的CUDA流。
    auto profileStream = samplesCommon::makeCudaStream();
    if (!profileStream)
    {
        return false;
    }
    // 将配置文件中的性能流设置为profileStream ⬆️
    config->setProfileStream(*profileStream);

    // 使用构建器构建序列化网络,并将结果存储在plan中
    SampleUniquePtr<IHostMemory> plan{builder->buildSerializedNetwork(*network, *config)};
    if (!plan)
    {
        return false;
    }

    // 创建TensorRT运行时对象并使用日志记录器创建
    mRuntime = std::shared_ptr<nvinfer1::IRuntime>(createInferRuntime(sample::gLogger.getTRTLogger()));
    if (!mRuntime)
    {
        return false;
    }

    // 反序列化plan中的数据以创建CUDA引擎,并使用自定义的删除器
    // plan 是一个 SampleUniquePtr<IHostMemory> 类型的智能指针对象,用于管理一个主机内存资源。
    // 在这里,plan 用于存储序列化的引擎数据,
    // 通过调用 plan->data() 和 plan->size() 可以获取引擎数据的指针和大小。
    // 这些数据会被传递给 mRuntime->deserializeCudaEngine 方法进行反序列化。
    /*
    1. plan->data() 和 plan->size() 用于获取序列化的引擎数据的指针和大小。
    2. 这些数据作为参数传递给 mRuntime->deserializeCudaEngine 方法,该方法用于反序列化 CUDA 引擎。
    3. 反序列化后的 CUDA 引擎被 std::shared_ptr<nvinfer1::ICudaEngine> 类型的智能指针 mEngine 所管理。
    4. samplesCommon::InferDeleter() 用作自定义的删除器,确保在引擎不再需要时正确释放资源。
    */
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        mRuntime->deserializeCudaEngine(plan->data(), plan->size()), 
        samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    /*
    1个输入和1个输出
    输入dim = 4
    输出dim = 2
    输入维度为4维可能表示输入数据是一个四维张量,例如图像数据通常具有通道、高度和宽度三个维度,再加上批次维度,因此是四维的。
    输出维度为2维可能表示输出数据是一个二维张量,例如分类任务中可能输出类别的概率分布,因此是二维的。   
    */
    // 确保网络只有一个输入,Nb 可能是指网络(network)中的输入数量(Number of Inputs)的缩写。因此,network->getNbInputs() 可能是用来获取网络中输入的数量。
    ASSERT(network->getNbInputs() == 1);
    // 获取输入维度信息并存储在mInputDims中
    mInputDims = network->getInput(0)->getDimensions();
    // 确保输入维度为4
    ASSERT(mInputDims.nbDims == 4);

    // 确保网络只有一个输出
    ASSERT(network->getNbOutputs() == 1);
    // 获取输出维度信息并存储在mOutputDims中
    mOutputDims = network->getOutput(0)->getDimensions();
    // 确保输出维度为2
    ASSERT(mOutputDims.nbDims == 2);
    return true;
}
//!
//! \brief 使用 ONNX 解析器创建 Onnx MNIST 网络并标记输出层
//!
//! \param network 指向将填充 Onnx MNIST 网络的网络指针
//!
//! \param builder 指向引擎构建器的指针
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, 
    SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    // 从文件中解析 ONNX 模型
    auto parsed = parser->parseFromFile(
        locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(), // 解析的ONNX模型文件的路径
        // locateFile返回的是std::string类型的文件路径,因此需要通过.c_str()将其转换为C风格的字符串。
        /* std::string  vs   C语言的字符串
        1. 存储方式:
        std::string: 是C++标准库提供的字符串类,可以动态调整大小,内部维护字符串的长度和内容。
        C风格的字符串: 是以null结尾的字符数组,长度固定,需要手动管理内存。
        2. 使用方式:
        std::string: 可以直接使用+操作符进行字符串拼接,支持各种字符串操作函数。
        C风格的字符串: 需要使用C语言的字符串处理函数,如strcpy、strcat等,操作相对繁琐。
        */
        static_cast<int>(sample::gLogger.getReportableSeverity()));
        /* 
        static_cast是C++中的一种类型转换操作符,用于在编译时执行类型转换。它可以将一种数据类型转换为另一种数据类型,包括基本数据类型、指针类型和引用类型等。static_cast在编译时进行类型检查,因此在类型转换时更加安全,但需要程序员确保转换是合法的。
        在上下文中,static_cast<int>将某个值转换为int类型,以便在后续代码中使用整数值表示日志级别。
        getReportableSeverity是一个函数或方法,用于获取日志记录器(logger)的报告级别(severity)。
        在这种情况下,sample::gLogger.getReportableSeverity()可能是用于获取日志记录器sample::gLogger的报告级别,
        以便在日志记录中确定要报告的消息的严重程度。
        */
    if (!parsed)
    {
        return false;
    }


    // 根据参数设置网络构建配置
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16); 
        // config->setFlag(BuilderFlag::kFP16)设置构建配置中的标志为kFP16,表示使用FP16精度。
    }
    if (mParams.bf16)
    {
        config->setFlag(BuilderFlag::kBF16);
    }
    /*
    https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/amp_precision/amp_op_dev_guide_cn.html
    
    1. BF16(BFloat16): [9.2 × 10^{−41} , 3.38953139×10^{38}] 
        BF16使用16位表示浮点数,其中1位用于符号位,8位用于指数部分,7位用于尾数部分。
        BF16提供了较高的计算速度和较小的内存占用,适用于深度学习加速器(如TPU)等场景。
        BF16的精度介于FP16和FP32之间,适用于一些计算密集型任务。
    bf16最大正值
        0 11111110 1111111 
        指数部分 254 - 127 = 127  
        $2^{127} \times (1+ 1- 2^{-7}) = 2^{127} \times 1.9921875 = 3.38953139×10^{38}$
    bf16最小正值
        0 00000000 0000001
        指数部分 1 - 127 = -126 对于非规格数实际指数固定为 1−偏置值
        尾数部分 尾数的十进制值,因为是非规格数,没有隐含的1, 所以只有 2^{-7},不需要+1
        $2^{-126} \times  2^{-7} =  2^(-133) = 9.2 × 10^{−41}$

    注意
        0 00000001 0000001 
        指数部分  1 - 127 = -126 这是正规数 
        尾数部分 尾数的十进制值,正规数有隐含的1,所以需要1+2^{-7}
        $2^{-126} \times (1 + 2^{-7})  =  6.11 \times 10^{-5}$
    2. FP16(Half Precision): [5.96 \times 10^{-8}, 65504]
        FP16也使用16位表示浮点数,其中1位用于符号位,5位用于指数部分,10位用于尾数部分。
        FP16提供了较高的计算速度和较小的内存占用,适用于GPU等设备上的深度学习计算。
        FP16的精度较低,可能会导致数值精度损失,但在一些情况下可以提高计算速度和减少内存占用。
    FP16的最大值是由以下几部分构成:
        符号位为 0(表示正数)。
        指数位为 11110(最大指数值 - 1,因为 11111 用于表示无穷大和NaN)。
        尾数位为 1111111111(尾数位全为1提供了该指数级别下可能的最大尾数)。
        计算 FP16 的最大正规数
    给定上述结构,FP16的最大正规数计算如下:
        指数计算:最大指数位为 11110,代表的指数值为 2+4+8+16 = 30 ,偏移量15 ,30-15 = 15
            选择15作为偏移量是为了在5位指数中平衡正负数的表示范围。这样的设计确保了:
                能表示的最小指数是 -14(对于非正规数)和 -15(对于正规数)。
                能表示的最大指数是 +15。
                特殊编码(全0和全1的指数)被用来表示0、非正规数、无穷大和NaN。
        尾数计算:尾数位 1111111111 表示的额外值为 1−2^{−10} =0.9990234375。
        因此,最大正规数为 1.9990234375× 2 ^15 =65504。
    fp16的最小正值是
    指数部分 1 - 15 = -14 
    (0 00000 0000000001)_2	 = 2^{-14} \times 2^{-10} = 2^{-24} = (0.000000059604645)_{10}	 =  5.96 \times 10^{-8} 最小正值

    
    总的来说,BF16提供了更好的精度和表示范围,适用于一些需要较高精度的任务,
    而FP16提供了更快的计算速度和更小的内存占用,适用于一些对精度要求不那么严格的任务。
    */


    /*
    在讨论FP16或BF16等浮点数格式的最大值、最小值以及表示范围时,我们通常指的是正规数(normalized numbers)。这是因为正规数能够使用其所有的位精确地表示数值,并利用隐含的前导一优化存储空间。非正规数通常用于填补零与最小正规数之间的间隙,允许接近零的值能够被表示,尽管这些数的精度较低。
        关于正规数和非正规数
        最大值:通常指最大的正规数,这是浮点格式可以表示的最大的非无穷大数值。
        最小值:在不同上下文中,这可能指最小的正规数或最小的正非正规数。最小的正规数是除零外,可以表示的最小的正数,并且具有完整的精度。而最小的正非正规数是可以表示的最接近零的正数,但精度较低。
    正规数(Normalized Numbers)
        正规数是浮点数格式中的一种数,其中包括:

        非零的指数。
        一个隐含的前导一(即尾数部分在数学上总是以1开始)。
        对于正规数,其二进制表示的指数部分从最小的正值开始(全0保留为表示非正规数和零),使得尾数能够用其全部位精确地表示数字,而指数则提供了数值的大小级。例如,在FP16中,最小正规指数是 -14(编码为 00001),对应于二进制表示中的实际指数 1−15=−14。

    非正规数(Subnormal Numbers)
        非正规数用于表示非常接近于零的数值,这些数值太小,不能以正规数的形式表示。在非正规数的表示中:

        指数部分为全0,这是一个特殊编码,表示这些数的实际指数比正规数的最小指数还要小一级(在FP16中为 −14−1=−15)。
        没有隐含的前导一,尾数直接从0开始,这减少了数值的精度,但增加了接近零的表示范围。
        非正规数的引入主要是为了解决“下溢”问题,即当数值过小,无法表示为正规数时,如果没有非正规数,这些值就会直接四舍五入为零。使用非正规数可以让我们表示和处理这些极小的值,虽然这样做牺牲了一些精度。
    */
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        // 设置所有动态范围为指定值
        samplesCommon::setAllDynamicRanges(network.get(), 127.0F, 127.0F);
        // 调用samplesCommon::setAllDynamicRanges(network.get(), 127.0F, 127.0F)来设置所有动态范围为指定值127.0,
        // 这是因为在INT8精度下需要设置动态范围。
        /*
        在INT8量化过程中,将浮点数转换为整数时需要考虑数值的范围。
        动态范围指的是输入数据的范围,即数据的最大值和最小值。
        在INT8精度下,数据被量化为8位整数,因此需要将浮点数映射到整数范囋内。
        设置动态范围可以帮助模型在INT8精度下正确地量化和推理,确保模型的准确性和性能。
        */
    }

    // 启用深度学习加速器(DLA)
    /*
    DLA代表Deep Learning Accelerator,是一种用于加速深度学习推理的硬件加速器。dlaCore是指DLA的核心编号,用于指定在具有多个DLA核心的平台上要使用的特定DLA核心。在这里,mParams.dlaCore可能是从参数中获取的DLA核心编号,用于启用特定的DLA核心来加速深度学习推理。
    */
    samplesCommon::enableDLA(
        builder.get(), 
        config.get(),
        mParams.dlaCore);

    return true;
}

//!
//! \brief 运行TensorRT推理引擎以执行该示例
//!
//! \details 该函数是示例的主要执行函数。它分配缓冲区,设置输入并执行引擎。
//!
bool SampleOnnxMNIST::infer()
{
    // 创建RAII缓冲区管理器对象
    samplesCommon::BufferManager buffers(mEngine);
    /*
    RAII(Resource Acquisition Is Initialization)是一种C++编程技术,用于在对象的构造函数中获取资源(如内存、文件句柄等),并在对象的析构函数中释放这些资源。
    RAII的核心思想是资源的生命周期与对象的生命周期绑定,通过对象的构造和析构来管理资源的获取和释放,从而确保资源在适当的时候被正确释放,避免资源泄漏。
    在这段代码中,samplesCommon::BufferManager对象被称为RAII缓冲区管理器对象,它可能在构造函数中获取与TensorRT推理引擎相关的资源(如内存缓冲区),并在对象生命周期结束时自动释放这些资源,以确保资源的正确管理和释放。RAII是一种常见的资源管理技术,可以提高代码的健壮性和可维护性。
    */

    // 创建执行上下文对象
    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(
        mEngine->createExecutionContext()
        );
    // 检查执行上下文是否成功创建
    if (!context)
    {
        return false;
    }

    // 遍历所有输入输出张量
    for (int32_t i = 0, e = mEngine->getNbIOTensors(); i < e; i++)
    {
        // 获取张量名称
        auto const name = mEngine->getIOTensorName(i);
        // 设置张量绑定维度
        context->setBindingDimensions(i, buffers.getBindingDimensions(name));
        /*
        在TensorRT中,设置张量绑定维度是为了确保输入和输出张量的维度匹配。
        在推理过程中,TensorRT需要知道每个张量的维度信息,以便正确地处理数据流和计算。
        通过设置张量绑定维度,可以确保输入和输出张量的维度与模型期望的维度一致,从而避免维度不匹配导致的错误或异常情况。
        这样可以确保推理过程顺利进行,并得到正确的结果。
        */

    }

    // 确保输入张量名称列表中只有一个张量
    ASSERT(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers))
    {
        return false;
    }

    // 从主机host输入缓冲区复制到设备device输入缓冲区
    // 从cpu内存转到gpu显存
    buffers.copyInputToDevice();

    // 开始执行推理操作,并返回一个布尔值来表示推理操作的执行状态。
    bool status = context->executeV2(
        buffers.getDeviceBindings().data());
        /* 
    在这段代码中,context->executeV2(buffers.getDeviceBindings().data())表示执行TensorRT的推理引擎,
    使用buffers.getDeviceBindings().data()传递gpu设备上的绑定数据。

        getDeviceBindings是一个方法,用于获取与设备相关的绑定数据。
        在TensorRT中,绑定数据是指将输入和输出张量绑定到执行上下文的过程中所需的数据。
        通过调用getDeviceBindings方法,可以获取设备上的绑定数据,以便在执行推理操作时传递给执行上下文,确保正确的数据流和计算。
        这个方法通常用于在推理过程中管理和传递设备上的数据。

    这个方法会执行推理操作并返回一个布尔值,表示推理操作的执行状态。

    这段代码 context->executeV2(buffers.getDeviceBindings().data()) 实际上是在执行推理操作。
    具体来说,它调用了TensorRT的执行上下文(IExecutionContext)的 executeV2 方法,
    传递了设备上的绑定数据作为参数,以触发推理过程。
    这个方法会开始执行推理操作,并返回一个布尔值来表示推理操作的执行状态。
    在这个语句中,推理操作是立即开始执行的,而不是等待推理完成。
    因此,这段代码会触发推理操作并继续执行后续的代码,而不会阻塞等待推理完成。
    如果需要等待推理完成,通常会在这段代码后面添加适当的逻辑来等待推理操作的结果。
    */
    if (!status)
    {
        return false;
    }

    // 从设备输出缓冲区复制到主机输出缓冲区
    // 从gpu显存转回cpu内存
    buffers.copyOutputToHost();

    // 验证结果
    if (!verifyOutput(buffers))
    {
        return false;
    }

    return true;
}

//!
//! \brief 读取输入数据并将结果存储在管理的缓冲区中
//!
bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers)
{
    const int inputH = mInputDims.d[2];
    const int inputW = mInputDims.d[3];

    // 读取一个随机数字文件
    /*
    1. srand(unsigned(time(nullptr))): 这行代码用于设置随机数生成器的种子。通过使用当前时间作为种子,可以确保每次运行程序时生成的随机数序列是不同的,增加随机性。
    2. std::vector<uint8_t> fileData(inputH * inputW):
        这行代码创建了一个名为fileData的std::vector,用于存储图像数据。
        fileData的大小为inputH * inputW,即图像的高度乘以宽度。
    3. mNumber = rand() % 10;: 
        这行代码生成一个范围在0到9之间的随机数字,并将其存储在mNumber变量中。
        这个随机数字通常用作文件名的一部分,用于读取对应的PGM文件。
    4. readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);: 
        这行代码调用readPGMFile函数,
        该函数用于读取PGM格式的图像文件并将其存储在fileData中。
        locateFile函数用于定位文件路径,
        根据随机生成的数字和数据目录来确定要读取的文件。
        读取的图像数据将存储在fileData中,以便后续处理和输入模型。
    总体而言,这段代码的目的是生成随机数作为文件名的一部分,读取对应的PGM图像文件,并将图像数据存储在fileData中供后续处理使用。
    */
    srand(unsigned(time(nullptr)));
    std::vector<uint8_t> fileData(inputH * inputW);
    mNumber = rand() % 10;
    readPGMFile(
        locateFile(
            std::to_string(mNumber) + ".pgm", 
            mParams.dataDirs), 
        fileData.data(),  // fileData是一个std::vector对象,而fileData.data()是std::vector类的成员函数,用于返回指向std::vector中存储数据的指针。在这里,fileData.data()返回的是指向fileData中存储的数据的指针,以便将这些数据传递给readPGMFile函数进行处理。通过使用fileData.data(),可以直接访问std::vector中的数据,并将其传递给其他函数或操作。
        inputH, 
        inputW);

    // 打印ASCII表示
    sample::gLogInfo << "输入:" << std::endl;
    for (int i = 0; i < inputH * inputW; i++)
    {
        sample::gLogInfo << (
            " .:-=+*#%@"[fileData[i] / 26]) << ( // 根据fileData[i]的值计算出对应的ASCII字符。fileData[i]是图像数据中的一个像素值,通过除以26来映射到ASCII字符集合中的一个字符。不同的像素值将映射到不同的字符,从而实现将图像数据转换为可视化的ASCII表示。
                ((i + 1) % inputW) ? "" : "\n"); // 当(i + 1) % inputW的结果为0时,表示已经输出了一行的像素数据,需要换行显示下一行的像素数据。如果结果不为0,则继续在同一行输出像素数据。
    }
    sample::gLogInfo << std::endl;

    // 这行代码从BufferManager中获取主机缓冲区,并将其转换为float*类型的指针。
    // 这个主机缓冲区用于存储输入数据,对应于模型的输入张量。
    float* hostDataBuffer = static_cast<float*>( // 转换为float*类型指针.模型的输入数据通常是浮点数类型,所以需要将主机缓冲区转换为float*类型的指针,以便存储浮点数数据。
        buffers.getHostBuffer( // 获取host缓冲区,从BufferManager中获取主机缓冲区。主机缓冲区是在主机内存中分配的用于存储模型输入数据的空间。
            mParams.inputTensorNames[0] 
            )
        );
    for (int i = 0; i < inputH * inputW; i++)
    {
        hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0);
        /*
        在循环中,将每个像素值从uint8_t类型转换为float类型,并进行归一化处理。
        具体地,将每个像素值除以255.0(最大像素值)并减去1.0,
        图像像素值通常在[0, 255]的范围内,通过除以255,可以将像素值缩放到[0, 1]的范围内
        以将像素值范围从[0, 255]映射到[-1.0, 0.0],适应模型的输入要求。
        */

        /*
        虽然已经将主机缓冲区的指针转换为float*类型,但仍需要对每个像素值进行类型转换和归一化处理的原因如下:
            1. 数据类型一致性:即使主机缓冲区的指针类型已经转换为float*,但实际存储在缓冲区中的数据仍然是uint8_t类型(图像像素值)。因此,在将数据存储到主机缓冲区之前,仍需要将uint8_t类型的像素值转换为float类型,以确保数据类型一致性。
            2. 归一化处理:在深度学习模型中,通常需要对输入数据进行归一化处理,以确保数据落在合适的范围内(通常是[-1.0, 1.0]或[0.0, 1.0])。在这里,将每个像素值除以255.0(最大像素值)并减去1.0,将像素值范围从[0, 255]映射到[-1.0, 0.0],以适应模型的输入要求。

        */
    }

    return true;

    /*
    这段代码是SampleOnnxMNIST类中的一个方法,名为processInput,用于处理输入数据。具体解释如下:
    1. 首先,根据输入张量的维度信息,确定输入数据的高度和宽度(inputH和inputW)。
    2. 然后,生成一个随机数字(mNumber),并读取对应数字的PGM文件数据到fileData中。PGM文件是一种图像文件格式,用于存储灰度图像数据。
    3. 接下来,通过循环遍历fileData中的数据,并将其转换为ASCII字符表示,输出到日志中,以便查看输入数据的可视化表示。
    4. 继续,将fileData中的数据转换为浮点数,并存储到hostDataBuffer中。这个hostDataBuffer是从BufferManager中获取的主机缓冲区,用于存储输入数据。
    5. 最后,返回true表示输入数据处理成功。
    数据流向如下:
    PGM文件数据 -> fileData -> ASCII字符表示 -> hostDataBuffer -> 主机缓冲区(用于输入数据)。
    */
}
//!
//! \brief 对数字进行分类并验证结果
//!
//! \return 分类输出是否符合预期
//!
bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
{
    // 获取输出张量的大小
    const int outputSize = mOutputDims.d[1];
    // 获取输出数据的指针
    float* output = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));
    float val{0.0F};
    int idx{0};
    /*
    1. float val{0.0F};: 这行代码声明并初始化了一个float类型的变量val,并将其初始值设为0.0。这个变量val用于存储计算过程中的最大值。
    2. int idx{0};: 这行代码声明并初始化了一个int类型的变量idx,并将其初始值设为0。这个变量idx用于存储计算过程中的索引值。
    */

    // 计算Softmax
    /*
    $$ y_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} $$
    */
    float sum{0.0F}; // sum用于存储Softmax函数中的分母部分,即指数函数的和
    for (int i = 0; i < outputSize; i++)
    {
        output[i] = exp(output[i]);
        sum += output[i];
    }

    sample::gLogInfo << "输出:" << std::endl;
    for (int i = 0; i < outputSize; i++)
    {
        // 归一化输出值
        output[i] /= sum;
        val = std::max(val, output[i]);
        if (val == output[i])
        {
            idx = i;
        }

        sample::gLogInfo << " 概率 " << i << "  " << std::fixed << std::setw(5) << std::setprecision(4) << output[i]  // std::fixed << std::setw(5) << std::setprecision(4) << output[i]: 设置输出格式,固定小数点格式,总宽度为5,小数点后保留4位,输出当前类别的概率值 output[i]
                         << " "
                         << "类别 " << i << ": " << std::string(int(std::floor(output[i] * 10 + 0.5F)), '*') // std::string(int(std::floor(output[i] * 10 + 0.5F)), '*'): 根据概率值将其转换为星号形式的可视化表示。将概率值乘以10并四舍五入,然后转换为整数,表示星号的数量,最终用星号表示概率大小。
                         << std::endl;
    }
    /*
    将 output[i] 乘以 10 并四舍五入,以将概率转换为 0 到 10 之间的整数。然后,这个整数被用作构造一个由相应数量星号组成的字符串,星号数量反映了概率的大小。
    假设输出的概率值为 [0.1234, 0.5678, 0.9102],对应3个类别。输出结果如下:
    概率 0  0.1234 类别 0: *
    概率 1  0.5678 类别 1: *****
    概率 2  0.9102 类别 2: *********
    */

    sample::gLogInfo << std::endl;

    return idx == mNumber && val > 0.9F;
    // 如果没有概率大于 0.9 的类别,则在给定的条件下,idx 的值将保持为初始赋值时的值,即 0。这是因为条件 val > 0.9F 会导致整个条件表达式的结果为 false,因此 idx 的值不会被更新为其他值。
    // 在这段代码中,idx 的值表示具有最大概率值的类别索引,而条件 val > 0.9F 要求最大概率值必须大于 0.9 才能满足条件。如果没有概率大于 0.9 的类别,则 idx 仍然保持为初始赋值时的值,不会被更新为其他值。
}

//!
//! \brief 使用命令行参数初始化params结构的成员
//!
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    // 创建一个OnnxSampleParams对象params
    samplesCommon::OnnxSampleParams params;
    
    // 如果用户未提供目录路径,则使用默认目录
    if (args.dataDirs.empty())
    {
        // 添加默认数据目录路径到params的dataDirs中
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }
    else // 如果用户提供了数据目录,则使用用户提供的目录
    {
        // 将args中的dataDirs赋值给params的dataDirs
        params.dataDirs = args.dataDirs;
    }
    
    // 设置onnx文件名为"mnist.onnx"
    params.onnxFileName = "mnist.onnx";
    
    // 将输入张量名称"Input3"添加到params的inputTensorNames中
    params.inputTensorNames.push_back("Input3");
    
    // 将输出张量名称"Plus214_Output_0"添加到params的outputTensorNames中
    params.outputTensorNames.push_back("Plus214_Output_0");
    
    // 将args中的useDLACore赋值给params的dlaCore
    params.dlaCore = args.useDLACore;
    
    // 将args中的runInInt8赋值给params的int8
    params.int8 = args.runInInt8;
    
    // 将args中的runInFp16赋值给params的fp16
    params.fp16 = args.runInFp16;
    
    // 将args中的runInBf16赋值给params的bf16
    params.bf16 = args.runInBf16;

    return params;
}

//!
//! \brief 打印此示例的运行帮助信息
//!
void printHelpInfo()
{
    // 打印用法信息
    std::cout
        << "用法: ./sample_onnx_mnist [-h 或 --help] [-d 或 --datadir=<数据目录路径>] [--useDLACore=<整数>]"
        << std::endl;
    // 显示帮助信息
    std::cout << "--help          显示帮助信息" << std::endl;
    // 指定数据目录
    std::cout << "--datadir       指定数据目录路径,覆盖默认设置。此选项可多次使用以添加多个目录。如果未提供数据目录,则默认使用(data/samples/mnist/, data/mnist/)"
              << std::endl;
    // 指定 DLA 核心
    std::cout << "--useDLACore=N  为支持 DLA 的层指定 DLA 引擎。值范围从 0 到 n-1,其中 n 是平台上的 DLA 引擎数量。"
              << std::endl;
    // 运行 Int8 模式
    std::cout << "--int8          以 Int8 模式运行" << std::endl;
    // 运行 FP16 模式
    std::cout << "--fp16          以 FP16 模式运行" << std::endl;
    // 运行 BF16 模式
    std::cout << "--bf16          以 BF16 模式运行" << std::endl;
}

int main(int argc, char** argv)
{
    // 解析命令行参数
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    
    // 检查参数是否有效,如果无效则打印错误信息并显示帮助信息
    if (!argsOK)
    {
        sample::gLogError << "无效参数" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    
    // 如果用户请求帮助信息,则显示帮助信息并返回成功
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }

    // 定义并开始测试
    auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
    sample::gLogger.reportTestStart(sampleTest);

    // 初始化样本参数
    SampleOnnxMNIST sample(initializeSampleParams(args));

    // 输出信息:构建并运行一个用于 Onnx MNIST 的 GPU 推理引擎
    sample::gLogInfo << "构建并运行一个用于 Onnx MNIST 的 GPU 推理引擎" << std::endl;

    // 构建推理引擎,如果失败则报告失败
    if (!sample.build())
    {
        return sample::gLogger.reportFail(sampleTest);
    }
    
    // 进行推理,如果失败则报告失败
    if (!sample.infer())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    // 报告测试通过
    return sample::gLogger.reportPass(sampleTest);
}