SAM通过transformer将点、框、Mask、文本等prompt和图片进行编码学习,可以实现对图片任意目标的分割
什么是 SAM ?
- a)SAM 利用“图片-分割提示”实现对图片上任意目标的分割,分割提示包括:点、框、Mask、文本
- b) SAM 首先利用 prompt encoder 编码"分割提示",利用 image encoder 编码“图片”,然后通过 Mask decoder 解析输出 Mask
- c)SAM 利用数据驱动去做模型训练,模型输出结果后再输入模型训练
SAM 的网络结构?
- image encoder:类似 VIT 的过程,输入 image (1,3, H, W), 输出 image_embedding (1, C, H/16, W/16),即 (1, HW/256, C)的 tokens 表示
- mask:mask prompt,直接和image_embedding相加即可
- prompt encoder:包含3种提示的编码过程,其中点、框按位置被编码为Pos embedding(1,N,C),文本通过clip模型被编码为Pos embedding(1,M,C)
- mask decoder:根据image_embedding和prompt encoder输出,结合IOU tokens(1,1,C)和mask tokens(1,P,C),解析出目标mask(1,1+P+N+M, H/16, W/16)和iou(1,1+P+N+M)
SAM 的 image encoder?
- 类似 VIT 的 encoder 过程,输入 image (1,3, H, W), 输出 image_embedding (1, C, H/16, W/16),即 (1, HW/256, C)的 tokens 表示
-
image_encoder=ImageEncoderViT(..) # batched_input={List,List} -> torch.Size([2, 3, 1024, 1024]) input_images = torch.stack([preprocess(x["image"]) for x in batched_input], dim=0) # torch.Size([2, 3, 1024, 1024]) -> torch.Size([2, 256, 64, 64]) image_embeddings = image_encoder(input_images)
SAM 的 prompt encoder?
- 包含3种提示的编码过程,其中点、框按位置被编码为 Pos embedding (1, N, C),文本通过 clip 模型被编码为 Pos embedding (1, M, C),最终输出(1,N+M,C )的稀疏编码sparse_embeddings
- point&box:每个点编码为1个 pos embedding,每个 box 编码为2个 pos embedding(box 被两个点定义)
embed_dim=256 num_point_embeddings: int = 4 # pos/neg point + 2 box corners point_embeddings = [nn.Embedding(1, embed_dim) for i in range(num_point_embeddings)] point_embeddings = nn.ModuleList(point_embeddings) not_a_point_embed = nn.Embedding(1, embed_dim) # point prompt points = points + 0.5 # Shift to center of pixel # 根据点位置points,在输入(1024,1024)的基础上生成pos embedding point_embedding = pe_layer.forward_with_coords(points, input_image_size) #torch.Size([1,3,2])+(1024,1024)->torch.Size([1,3,256]) # 点有3类,-1表示非嵌入点,此时不使用pos embedding,0表示正样本点,1表示负样本点 point_embedding[labels == -1] = 0.0 point_embedding[labels == -1] += not_a_point_embed.weight point_embedding[labels == 0] += point_embeddings[0].weight point_embedding[labels == 1] += point_embeddings[1].weight # box prompt boxes = boxes + 0.5 # Shift to center of pixel coords = boxes.reshape(-1, 2, 2) # 一个框肯定2个点 corner_embedding = pe_layer.forward_with_coords(coords, input_image_size) corner_embedding[:, 0, :] += point_embeddings[2].weight #框第一个点 corner_embedding[:, 1, :] += point_embeddings[3].weight #框第二个点 # 汇总point、box编码 sparse_embeddings = torch.empty((1, 0, embed_dim)) sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) - text:通过CLIP模型将文本编码到(1,M,C)
SAM的mask prompt如何处理?
- mask利用CNN输出和image_embedding(1,C,H/16,W/16)一样大小的编码,后续直接相加
mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
no_mask_embed = nn.Embedding(1, embed_dim)
if masks is not None:
dense_embeddings = self._embed_masks(masks) # 利用CNN生成mask embedding
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
) # 随机初始化生成mask embedding
SAM 的 mask decoder?
- 输入:image_embedding(1, C, H/16, W/16)、image_embedding大小的位置编码image_pe(1, C, H/16, W/16)、稀疏提示编码sparse_prompt_embeddings(1, N, C)、密集提示编码dense_prompt_embeddings(1,C,H/16, W/16)
- (1)tansformer整合所有编码:将image_embedding+dense_prompt_embeddings视为transformer encoder的k,image_pe视为pos embedding,sparse_prompt_embeddings视为decoder的q,并且参考VIT的class_token,不直接使用sparse_prompt_embeddings输出作为最终结果,而是另外生成1个iou token和P个mask token作为最终结果,所以输入transformer decoder的token变为(1,1+P+N,C),经过transformer后decoder和encoder分别输出hs(1,1+P+N,C), src(1,HW/256,C);
num_multimask_outputs=3 transformer_dim=256 iou_token = nn.Embedding(1, transformer_dim) num_mask_tokens = num_multimask_outputs + 1 mask_tokens = nn.Embedding(num_mask_tokens, transformer_dim) # Concatenate output tokens output_tokens = torch.cat([iou_token.weight, mask_tokens.weight], dim=0) # torch.Size([5, 256]) output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) # torch.Size([1, 5, 256]) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # torch.Size([1, 12, 256]) # Expand per-image data in batch direction to be per-mask src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) # torch.Size([1, 256, 64, 64]) -》torch.Size([1, 256, 64, 64]) src = src + dense_prompt_embeddings # torch.Size([1, 256, 64, 64])+torch.Size([1, 256, 64, 64]) pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) # torch.Size([1, 256, 64, 64]) b, c, h, w = src.shape # Run the transformer torch.Size([1, 256, 64, 64]),torch.Size([1, 256, 64, 64]),torch.Size([1, 12, 256]) hs, src = transformer(src, pos_src, tokens) # torch.Size([1, 12, 256]) torch.Size([1, 4096, 256]) = q,k iou_token_out = hs[:, 0, :] # torch.Size([1, 256]) mask_tokens_out = hs[:, 1 : (1 + num_mask_tokens), :] # torch.Size([1, 4, 256]) - (2)生成Mask预测:取hs的第1-P个token作为预测结果mask_tokens_out,src经过反卷积上采样4倍,输出upscaled_embedding(1,HW/16,C'),mask_tokens_out经过MLP操作,将隐变量长度变为C',即输出hyper_in(1,P,C'),hyper_in与upscaled_embedding点乘后输出masks(1,P,HW/16),表示p个mask
self.output_upscaling = nn.Sequential( nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), LayerNorm2d(transformer_dim // 4), activation(), nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), activation(), ) self.output_hypernetworks_mlps = nn.ModuleList( [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens)]) # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w) # torch.Size([1, 256, 64, 64]) upscaled_embedding = self.output_upscaling(src) # torch.Size([1, 256, 64, 64]) -》torch.Size([1, 32, 256, 256]) hyper_in_list: List[torch.Tensor] = [] for i in range(self.num_mask_tokens): hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) # torch.Size([1, 32])x4 hyper_in = torch.stack(hyper_in_list, dim=1) # torch.Size([1, 4, 32]) b, c, h, w = upscaled_embedding.shape # torch.Size([1, 32, 256, 256]) # 运算符@表示矩阵的点乘 masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # torch.Size([1, 4, 32]) @ torch.Size([1, 32, 256, 256]) -> torch.Size([1, 4, 256, 256]) - (3)生成IOU预测:取hs的第1个token作为预测结果iou_token_out,然后使用MLP将隐变量长度变为P,表示P各mask的iou预测
iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, num_mask_tokens, iou_head_depth) # Generate mask quality predictions iou_pred = iou_prediction_head(iou_token_out) # torch.Size([1,256]) -> torch.Size([1, 4])
SAM 如何直接分割所有目标?
- 以原图所有cell作为point prompt输入,输出Mask和iou后,通过iou阈值过滤mask,得到所有目标的mask
参考: