VLM模型中高分辨率图像降低token数的几种方式

本文调研整理了VLM常用的高分辨率图像降低token数的方法,包括qformer、pooling、Ldp、s2wrapper等。

  1. cross-attention/Qformer

1024 --> 96

  1. concat + mlp

把相邻的4个 token concat到一起,然后用线性层映射到1个token

  1. Pooling

相邻的4个 token 做 pooling

  1. Pixel shuffle + mlp

Pixel shuffle 原本用于图像高分辨率(把图像通道数拆分到空间维度上)

pixelshuffle算法的实现流程如下图,其实现的功能是:将一个H × W的低分辨率输入图像(Low Resolution),通过Sub-pixel操作将其变为rH x rW的高分辨率图像(High Resolution)。

但是其实现过程不是直接通过插值等方式产生这个高分辨率图像,而是通过卷积先得到 r^2个通道的特征图(特征图大小和输入低分辨率图像一致),然后通过周期筛选(periodic shuffing)的方法得到这个高分辨率的图像,其中r为上采样因子(upscaling factor),也就是图像的扩大倍率。

但InternVL 把该方法用在降低分辨率上。

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        return x

448448 --> 896896

336336 --> 672672

1. [1, 1024, 3200]
2. [1, 32, 32, 3200]
3. [1, 32, 16, 6400]
4. [1, 16, 32, 6400]
5. [1, 16, 16, 12800]
6. [1, 256, 12800] 

1024 * 4096
256 * 4096

5. ## LDPv2 (MobileVLM)

  1. 网络结构:

    1. LLM:mobile llama (1.4B、2.7B)

    2. vision:VIT-L-14 336px

    3. projector:LDPv2(Lightweight Downsample Projector v2)

      1. Depth-wise conv(特征变换,代码中实际用的是两层全连接)
      2. Average pooling (降低token,指定到12*12=144个)
      3. PEG Point-wise conv + residual(增强位置信息) —— 得益于这一步,视觉 token 数可以显著压缩到144
  2. 训练方式:

    1. Pretrain 和 sft 两个阶段,都打开 LLM 的训练
  3. SFT 阶段用更多的高质量数据(665k --> 3.6M, 翻了5倍)


class FeatureIRLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class TokenDownLayer(nn.Module):
    def __init__(self, shape) -> None:
        super().__init__()
        self.dwn = nn.Sequential(
            nn.AdaptiveAvgPool2d(shape)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = self.dwn(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class PosInjectLayer(nn.Module):
    # https://github.com/Meituan-AutoML/Twins/blob/main/gvt.py
    def __init__(self, in_dim: int, out_dim: int, stride: int = 1) -> None:
        super().__init__()
        self.peg = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 3, stride, 1, bias=True, groups=out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens
        cnn_feat = x.transpose(1, 2).view(b, c, h, h)
        x = self.peg(cnn_feat) + cnn_feat
        x = x.flatten(2).transpose(1, 2)
        return x

class LDPNetV2Projector(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.mlp = FeatureIRLayer(inc, ouc)
        self.dwn = TokenDownLayer((12, 12))
        self.peg = PosInjectLayer(ouc, ouc, stride=1)

    def forward(self, x):
        x = self.mlp(x)
        x = self.dwn(x)
        x = self.peg(x)
        return x
  1. S^2 wrapper

注意图中的维度变换,很清晰

[B, N, D] 在特征维度D上concat,而不是token length N上concat