介绍YOLO-NAS:最有效的目标检测算法之一

1,496 阅读8分钟

YOLO(You Only Look Once)是一种物体检测算法,使用深度神经网络模型,特别是卷积神经网络,实时检测和分类对象。自推出以来,由于其高准确性和速度,YOLO已成为最受欢迎的物体检测和分类任务算法之一。它在各种物体检测基准测试上实现了最先进的性能。

image.png

最近,于2023年5月的第一周,机器学习领域发布了YOLO-NAS模型,并且它具有无与伦比的精度和速度,超越了其他模型,如YOLOv7和YOLOv8。

image.png

YOLO-NAS模型是在COCO和Objects365等数据集上预先训练的,这使其适用于实际应用。它当前可在Deci的SuperGradients上使用,后者是一种基于PyTorch的库,包含近40个预先训练的模型,用于执行不同的计算机视觉任务,例如分类、检测、分割等。

现在让我们开始安装SuperGradients库,以开始使用YOLO-NAS!

# 安装 supergradients 库
!pip install super -gradients== 3.1 .0

导入和加载 YOLO-NAS

#importing models from supergradients' training module 
from super_gradients.training导入模型

下一步是启动模型。YOLO-NAS 有不同的型号,对于这款笔记本,我们将使用yolo_nas_l, 和pretrained_weights = 'coco'.

您可以在GitHub 页面中获取有关不同模型的更多信息。

# 初始化模型
yolo_nas = models.get( "yolo_nas_l" , pretrained_weights = "coco" )

模型架构

在下面的代码单元中,我们使用torchinfo 的摘要来获取 YOLO-NAS 架构,这有助于深入了解模型的运行方式。

# Yolo NAS 架构
!pip install torchinfo 
from torchinfo import summary 

summary(model = yolo_nas, 
       input_size = ( 16 , 3 , 640 , 640 ), 
       col_names = [ 'input_size' , 
                   'output_size' , 
                   'num_params' , 
                   'trainable' ], 
       col_width = 20 , 
       row_settings = [ 'var_names' ])
=================================================================================================================================================
Layer (type (var_name))                                           Input Shape          Output Shape         Param #              Trainable
=================================================================================================================================================
YoloNAS_L (YoloNAS_L)                                             [16, 3, 640, 640]    [16, 8400, 4]        --                   True
├─NStageBackbone (backbone)                                       [16, 3, 640, 640]    [16, 96, 160, 160]   --                   True
│    └─YoloNASStem (stem)                                         [16, 3, 640, 640]    [16, 48, 320, 320]   --                   True
│    │    └─QARepVGGBlock (conv)                                  [16, 3, 640, 640]    [16, 48, 320, 320]   3,024                True
│    └─YoloNASStage (stage1)                                      [16, 48, 320, 320]   [16, 96, 160, 160]   --                   True
│    │    └─QARepVGGBlock (downsample)                            [16, 48, 320, 320]   [16, 96, 160, 160]   88,128               True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 96, 160, 160]   [16, 96, 160, 160]   758,594              True
│    └─YoloNASStage (stage2)                                      [16, 96, 160, 160]   [16, 192, 80, 80]    --                   True
│    │    └─QARepVGGBlock (downsample)                            [16, 96, 160, 160]   [16, 192, 80, 80]    351,360              True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 192, 80, 80]    [16, 192, 80, 80]    2,045,315            True
│    └─YoloNASStage (stage3)                                      [16, 192, 80, 80]    [16, 384, 40, 40]    --                   True
│    │    └─QARepVGGBlock (downsample)                            [16, 192, 80, 80]    [16, 384, 40, 40]    1,403,136            True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 384, 40, 40]    [16, 384, 40, 40]    13,353,733           True
│    └─YoloNASStage (stage4)                                      [16, 384, 40, 40]    [16, 768, 20, 20]    --                   True
│    │    └─QARepVGGBlock (downsample)                            [16, 384, 40, 40]    [16, 768, 20, 20]    5,607,936            True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 768, 20, 20]    [16, 768, 20, 20]    22,298,114           True
│    └─SPP (context_module)                                       [16, 768, 20, 20]    [16, 768, 20, 20]    --                   True
│    │    └─Conv (cv1)                                            [16, 768, 20, 20]    [16, 384, 20, 20]    295,680              True
│    │    └─ModuleList (m)                                        --                   --                   --                   --
│    │    └─Conv (cv2)                                            [16, 1536, 20, 20]   [16, 768, 20, 20]    1,181,184            True
├─YoloNASPANNeckWithC2 (neck)                                     [16, 96, 160, 160]   [16, 96, 80, 80]     --                   True
│    └─YoloNASUpStage (neck1)                                     [16, 768, 20, 20]    [16, 192, 20, 20]    --                   True
│    │    └─Conv (reduce_skip1)                                   [16, 384, 40, 40]    [16, 192, 40, 40]    74,112               True
│    │    └─Conv (reduce_skip2)                                   [16, 192, 80, 80]    [16, 192, 80, 80]    37,248               True
│    │    └─Conv (downsample)                                     [16, 192, 80, 80]    [16, 192, 40, 40]    332,160              True
│    │    └─Conv (conv)                                           [16, 768, 20, 20]    [16, 192, 20, 20]    147,840              True
│    │    └─ConvTranspose2d (upsample)                            [16, 192, 20, 20]    [16, 192, 40, 40]    147,648              True
│    │    └─Conv (reduce_after_concat)                            [16, 576, 40, 40]    [16, 192, 40, 40]    110,976              True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 192, 40, 40]    [16, 192, 40, 40]    2,595,716            True
│    └─YoloNASUpStage (neck2)                                     [16, 192, 40, 40]    [16, 96, 40, 40]     --                   True
│    │    └─Conv (reduce_skip1)                                   [16, 192, 80, 80]    [16, 96, 80, 80]     18,624               True
│    │    └─Conv (reduce_skip2)                                   [16, 96, 160, 160]   [16, 96, 160, 160]   9,408                True
│    │    └─Conv (downsample)                                     [16, 96, 160, 160]   [16, 96, 80, 80]     83,136               True
│    │    └─Conv (conv)                                           [16, 192, 40, 40]    [16, 96, 40, 40]     18,624               True
│    │    └─ConvTranspose2d (upsample)                            [16, 96, 40, 40]     [16, 96, 80, 80]     36,960               True
│    │    └─Conv (reduce_after_concat)                            [16, 288, 80, 80]    [16, 96, 80, 80]     27,840               True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 96, 80, 80]     [16, 96, 80, 80]     2,546,372            True
│    └─YoloNASDownStage (neck3)                                   [16, 96, 80, 80]     [16, 192, 40, 40]    --                   True
│    │    └─Conv (conv)                                           [16, 96, 80, 80]     [16, 96, 40, 40]     83,136               True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 192, 40, 40]    [16, 192, 40, 40]    1,280,900            True
│    └─YoloNASDownStage (neck4)                                   [16, 192, 40, 40]    [16, 384, 20, 20]    --                   True
│    │    └─Conv (conv)                                           [16, 192, 40, 40]    [16, 192, 20, 20]    332,160              True
│    │    └─YoloNASCSPLayer (blocks)                              [16, 384, 20, 20]    [16, 384, 20, 20]    5,117,700            True
├─NDFLHeads (heads)                                               [16, 96, 80, 80]     [16, 8400, 4]        --                   True
│    └─YoloNASDFLHead (head1)                                     [16, 96, 80, 80]     [16, 68, 80, 80]     --                   True
│    │    └─ConvBNReLU (stem)                                     [16, 96, 80, 80]     [16, 128, 80, 80]    12,544               True
│    │    └─Sequential (cls_convs)                                [16, 128, 80, 80]    [16, 128, 80, 80]    147,712              True
│    │    └─Conv2d (cls_pred)                                     [16, 128, 80, 80]    [16, 80, 80, 80]     10,320               True
│    │    └─Sequential (reg_convs)                                [16, 128, 80, 80]    [16, 128, 80, 80]    147,712              True
│    │    └─Conv2d (reg_pred)                                     [16, 128, 80, 80]    [16, 68, 80, 80]     8,772                True
│    └─YoloNASDFLHead (head2)                                     [16, 192, 40, 40]    [16, 68, 40, 40]     --                   True
│    │    └─ConvBNReLU (stem)                                     [16, 192, 40, 40]    [16, 256, 40, 40]    49,664               True
│    │    └─Sequential (cls_convs)                                [16, 256, 40, 40]    [16, 256, 40, 40]    590,336              True
│    │    └─Conv2d (cls_pred)                                     [16, 256, 40, 40]    [16, 80, 40, 40]     20,560               True
│    │    └─Sequential (reg_convs)                                [16, 256, 40, 40]    [16, 256, 40, 40]    590,336              True
│    │    └─Conv2d (reg_pred)                                     [16, 256, 40, 40]    [16, 68, 40, 40]     17,476               True
│    └─YoloNASDFLHead (head3)                                     [16, 384, 20, 20]    [16, 68, 20, 20]     --                   True
│    │    └─ConvBNReLU (stem)                                     [16, 384, 20, 20]    [16, 512, 20, 20]    197,632              True
│    │    └─Sequential (cls_convs)                                [16, 512, 20, 20]    [16, 512, 20, 20]    2,360,320            True
│    │    └─Conv2d (cls_pred)                                     [16, 512, 20, 20]    [16, 80, 20, 20]     41,040               True
│    │    └─Sequential (reg_convs)                                [16, 512, 20, 20]    [16, 512, 20, 20]    2,360,320            True
│    │    └─Conv2d (reg_pred)                                     [16, 512, 20, 20]    [16, 68, 20, 20]     34,884               True
=================================================================================================================================================
Total params: 66,976,392
Trainable params: 66,976,392
Non-trainable params: 0
Total mult-adds (T): 1.04
=================================================================================================================================================
Input size (MB): 78.64
Forward/backward pass size (MB): 27238.60
Params size (MB): 178.12
Estimated Total Size (MB): 27495.37
=================================================================================================================================================

图像上的对象检测

我们现在可以测试模型在不同图像上检测对象的能力。

在下面的代码中,我们启动了一个名为image的变量,它接收一个包含图像的 URL。然后,我们可以使用predictshow方法来显示带有模型预测的图像。

image = "https://i.pinimg.com/736x/b4/29/48/b42948ef9202399f13d6e6b3b8330b20.jpg"
 yolo_nas.predict(image).show()

image.png

YOLO-NAS:图像上的目标检测

在上图中,我们可以看到对每个对象进行的检测以及模型在其自身预测中的置信度分数。例如,我们可以看到该模型对地板上的白色物体是杯子的置信度得分为 97%。然而,这张图片中有很多物体,我们可以看到模型将 Nintendo 64 游戏机误认为是汽车。

我们可以通过使用conf作为检测阈值的参数来改进我们的结果。例如,我们可以将此值更改为conf = 0.50,以便模型仅显示置信度得分高于 50% 的检测。让我们试试看。

image = "https://i.pinimg.com/736x/b4/29/48/b42948ef9202399f13d6e6b3b8330b20.jpg"
 yolo_nas.predict(image, conf = 0.50 ).show()

image.png

YOLO-NAS:图像上的目标检测

现在,该模型仅显示其检测中置信度分数至少为 50% 的对象,即杯子、电视和遥控器。

我们可以测试更多图像。

image.png

YOLO-NAS:图像上的目标检测

image.png

YOLO-NAS:图像上的目标检测

结论

我们使用新发布的YOLO-NAS模型对图像执行了初始目标检测任务。

但是,重要的是要强调您可以使用自定义数据集微调此模型,从而提高其在某些对象上的性能。