本文调研整理了VLM常用的高分辨率图像降低token数的方法,包括qformer、pooling、Ldp、s2wrapper等。
-
cross-attention/Qformer
1024 --> 96
-
concat + mlp
把相邻的4个 token concat到一起,然后用线性层映射到1个token
-
Pooling
相邻的4个 token 做 pooling
-
Pixel shuffle + mlp
Pixel shuffle 原本用于图像高分辨率(把图像通道数拆分到空间维度上)
pixelshuffle算法的实现流程如下图,其实现的功能是:将一个H × W的低分辨率输入图像(Low Resolution),通过Sub-pixel操作将其变为rH x rW的高分辨率图像(High Resolution)。
但是其实现过程不是直接通过插值等方式产生这个高分辨率图像,而是通过卷积先得到 r^2个通道的特征图(特征图大小和输入低分辨率图像一致),然后通过周期筛选(periodic shuffing)的方法得到这个高分辨率的图像,其中r为上采样因子(upscaling factor),也就是图像的扩大倍率。
但InternVL 把该方法用在降低分辨率上。
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
return x
448448 --> 896896
336336 --> 672672
1. [1, 1024, 3200]
2. [1, 32, 32, 3200]
3. [1, 32, 16, 6400]
4. [1, 16, 32, 6400]
5. [1, 16, 16, 12800]
6. [1, 256, 12800]
1024 * 4096
256 * 4096
5. ## LDPv2 (MobileVLM)
-
网络结构:
-
LLM:mobile llama (1.4B、2.7B)
-
vision:VIT-L-14 336px
-
projector:LDPv2(Lightweight Downsample Projector v2)
- Depth-wise conv(特征变换,代码中实际用的是两层全连接)
- Average pooling (降低token,指定到12*12=144个)
- PEG : Point-wise conv + residual(增强位置信息) —— 得益于这一步,视觉 token 数可以显著压缩到144
-
-
训练方式:
- Pretrain 和 sft 两个阶段,都打开 LLM 的训练
-
SFT 阶段用更多的高质量数据(665k --> 3.6M, 翻了5倍)
class FeatureIRLayer(nn.Module):
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
class TokenDownLayer(nn.Module):
def __init__(self, shape) -> None:
super().__init__()
self.dwn = nn.Sequential(
nn.AdaptiveAvgPool2d(shape)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
assert h * h == num_tokens
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = self.dwn(x)
x = x.flatten(2).transpose(1, 2)
return x
class PosInjectLayer(nn.Module):
# https://github.com/Meituan-AutoML/Twins/blob/main/gvt.py
def __init__(self, in_dim: int, out_dim: int, stride: int = 1) -> None:
super().__init__()
self.peg = nn.Sequential(
nn.Conv2d(in_dim, out_dim, 3, stride, 1, bias=True, groups=out_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
assert h * h == num_tokens
cnn_feat = x.transpose(1, 2).view(b, c, h, h)
x = self.peg(cnn_feat) + cnn_feat
x = x.flatten(2).transpose(1, 2)
return x
class LDPNetV2Projector(nn.Module):
def __init__(self, config=None):
super().__init__()
inc, ouc = config.mm_hidden_size, config.hidden_size
self.mlp = FeatureIRLayer(inc, ouc)
self.dwn = TokenDownLayer((12, 12))
self.peg = PosInjectLayer(ouc, ouc, stride=1)
def forward(self, x):
x = self.mlp(x)
x = self.dwn(x)
x = self.peg(x)
return x
-
S^2 wrapper
- 论文:arxiv.org/abs/2403.13…
- 代码:github.com/bfshi/scali… (实现很简洁,方便复用)
注意图中的维度变换,很清晰
[B, N, D] 在特征维度D上concat,而不是token length N上concat