Segment Anything Model基本使用学习心得

2,128 阅读8分钟

Segment Anything Model(简称SAM),是一个可以通过提示来进行图形分割的模型,非常的好用及方便。提示是通过位置点(point)和方形区域(box)进行提示的。

官方Demo链接,非常的酷: Segment Anything | Meta AI (segment-anything.com)

Github:github.com/facebookres…

文章将主要以predictor_example.ipynb为例,

  • 本地运行以理解该模型如何使用
  • 该文章比较适合有开发基础但又对Deeplearning不太熟悉的小白。
  • 该项目可以在mac电脑上以CPU运行。
  • 该项目可以在Colab上直接运行,传送门

Environment Set-up

using_colab = False
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
    
    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg
        
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

解析: 本地使用,请先将整个repo下载到本地,通过vscode将该repo打开,并通过pip或conda安装依赖。因此也就无需运行以上代码了。

pip install opencv-python matplotlib torch torchvision

另外请下载模型checkpoint到segment-anything/notebooks目录下。下载地址如下:

其他checkpoint都比这个小,猜想效果没有这个好,就不列出来了。

checkpoint知识点插播:

"Model Checkpoints"是指在训练神经网络模型过程中,定期保存模型的中间状态。模型训练可能需要很长时间,尤其是对于复杂的深度学习模型来说,可能需要数小时、数天甚至数周。在这个过程中,模型会不断更新和调整权重,以最小化损失函数并提高性能。

为了避免在训练过程中发生意外情况(如计算机崩溃、断电等)导致整个训练过程丢失,可以通过定期保存模型的检查点来解决这个问题。模型检查点是模型在训练过程中的快照,包含了当前训练状态下的权重、偏置和其他相关参数。这样,即使在训练过程中出现意外,可以从最近的检查点重新开始训练,而无需从头开始。

通过保存模型检查点,可以实现以下目标:

  1. 防止数据丢失:避免在训练过程中发生意外情况导致数据丢失。
  2. 恢复训练:如果训练过程中断,可以从最近的检查点重新开始训练,而不是从头开始。
  3. 参数回滚:可以根据需要回滚到之前的模型状态,以便进行比较或重新训练。

与"Model Checkpoints"相关的概念是"Model"(模型)。模型是通过训练数据和优化算法学习到的参数集合,用于进行预测或执行特定任务。模型包含了神经网络的结构和相应的权重、偏置等参数。模型检查点是对模型的一次保存,用于在训练过程中备份和恢复模型的状态。

简而言之,模型是经过训练得到的结果,而模型检查点是在训练过程中定期保存模型状态的快照,以便在需要时恢复训练或回滚到先前状态。

Set-up

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

导入依赖即可

  • show_mask主要是将通过模型预测出来的物体形状以mask的形式在图片上画出
  • show_points是将用来提示的点,展示在图片上
  • show_box是将用来提示的方形区域,展示在图片上

show_mask知识点插播:

具体解释如下:

  1. color = np.array([30/255, 144/255, 255/255, 0.6]) 这行代码创建一个包含四个元素的NumPy数组 color,表示颜色的RGBA值(红、绿、蓝和透明度)。这里的颜色值是以0到1的浮点数表示的,其中透明度为0.6。
  2. h, w = mask.shape[-2:] 这行代码获取掩码的形状(shape)中的高度(h)和宽度(w)。掩码是一个多维数组,通过获取最后两个维度的大小来获取高度和宽度。
  3. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 这行代码将掩码进行形状变换,使其与颜色数组具有相同的形状。mask.reshape(h, w, 1) 将掩码变形为高度为h、宽度为w、通道数为1的三维数组。color.reshape(1, 1, -1) 将颜色数组变形为高度为1、宽度为1、通道数为4的三维数组。然后,使用逐元素相乘的方式将掩码与颜色进行相乘,得到一个带有颜色的掩码图像。
  4. ax.imshow(mask_image) 这行代码使用图形界面(如Matplotlib)的 imshow 方法将带有颜色的掩码图像显示在轴(ax)上。这将在图形界面中显示一个具有指定颜色的区域,该区域由掩码定义。

总结起来,这段代码的目的是将给定的掩码(mask)与指定的颜色相乘,创建一个带有颜色的图像,并在图形界面上显示该图像。它利用了NumPy数组操作和图形界面库的功能来实现这个功能。

Example image

这段没啥好讲的,展示原图,先看看样子

Selecting objects with SAM

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# device = "cuda"
device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

加载模型的checkpoint。这里如果用mac或不支持cuda设备,可以像我一样,将device设置为CPU。

predictor.set_image(image)

通过调用SamPredictor.set_image对图像进行处理,以生成一个图像嵌入(image embedding)。SamPredictor将记住这个嵌入,并在后续的掩码预测中使用它。因此这步是比较耗时的,用本人的mac电脑花费了31秒的时间。 打印predictor结果如下: <segment_anything.predictor.SamPredictor object at 0x7fe6f930de10>

image.png

image embedding知识点插播

图像嵌入(Image Embedding)是一种将图像转换为向量表示的技术,旨在将图像的语义和特征以一种紧凑且可计算的方式进行编码。图像嵌入是深度学习中的一个重要概念,它通过将图像映射到一个低维向量空间,使得图像的语义信息能够以向量的形式进行表达和处理。

以下是一些关键点来详细介绍图像嵌入的概念:

  1. 特征提取:图像嵌入的核心思想是从图像中提取有意义的特征。深度学习模型(如卷积神经网络)常用于提取图像的低级特征(如边缘、纹理)和高级语义特征(如物体、场景等)。这些特征可以通过在预训练模型上进行特征提取或端到端的训练来获取。

  2. 降维:图像通常由大量像素组成,因此直接使用原始像素作为图像表示会导致高维度的特征向量。为了降低计算复杂度和提高表达效果,图像嵌入通常将特征向量映射到一个低维空间。这个低维向量空间的维度通常远远小于原始像素的数量,但仍然能够保留图像的重要信息。

  3. 语义一致性:图像嵌入的目标是使相似图像在嵌入空间中更加接近。这意味着在嵌入空间中,具有相似语义的图像在向量距离上更接近,而语义不相关的图像则更远离。这种语义一致性有助于在嵌入空间中进行图像检索、聚类和分类等任务。

  4. 应用:图像嵌入可以在多个计算机视觉任务中发挥作用。例如,通过计算图像之间的嵌入距离,可以进行图像检索和相似图像推荐。在图像聚类中,可以使用嵌入空间中的聚类算法将具有相似语义的图像分组。此外,图像嵌入还可以作为其他任务(如图像生成、目标检测、图像分割等)的输入或中间表示。

总结来说,图像嵌入是将图像转换为低维度向量表示的技术。它通过提取图像特征和降维,将图像的语义信息编码为紧凑且可计算的向量表示。图像嵌入在计算机视觉领域具有广泛的应用,可以用于图像检索、图像聚类、图像生成等任务。通过图像嵌入,我们可以将图像的语义信息以向量的形式表示,从而方便进行计算和比较,以实现各种图像相关的任务。

input_point = np.array([[500, 375]])
input_label = np.array([1])

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 

input_point是一个数组,横纵坐标代表了用来提示用的点坐标。这里只有一个提示点。 input_label也是一个数组,用来对应input_point中每一个点坐标是用来添加还是移除这个提示点来分割物体。1代表添加,0代表移除。这里只有一个点,且数值为1,代表添加这个提示点。

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  

将所有的预测都展示出来,并画出提示点和识别后的物体(用mask展示出来)

本人也用自己的图片测试了几次,效果不错,推荐。