RepVGG

111 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

arxiv.org/abs/2101.03…

其中a和b是控制宽度的,a是控制stage1-4的宽度,b控制stage5的宽度

还有个-Bg,g是指用了组卷积,g后面的数值代表组的数量。注意不是所有卷积层都换成了组卷积,其中是在2,4,6,8,10,12,14,16,18,20,22,26个卷积处使用,从stage2开始算,从1开始数,例如RepVGG-B,stage2重复4次,就是第2个和第4个使用

精度较高,速度几乎比EfficientNet V1快3倍

网络亮点

  • 核心:结构重参数化

网络结构

block

a为stride=2时的结构,b为stride=1时结构

  1. 更快
    并行度方面,看其中一个block,走3×3分支的要慢,所以别的要等它的结果
    MAC,内存访问成本,多分支这种每个分支都要访问一次输入特征图
    算子角度,多分支的要启动多次算子,而单分支的少
  2. 省内存
  3. 更灵活

结构重参数化

把BN层的参数移到卷积层中

下面是正常卷积过程:

将1×1卷积转换成3×3卷积

将BN转换成3×3卷积

多分支融合

所以就是将三个分支的卷积核的参数相加,偏置加在一起

上图I为输入特征图,K为卷积核参数,B为偏置

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def main():
    torch.random.manual_seed(0)

    f1 = torch.randn(1, 2, 3, 3)

    module = nn.Sequential(OrderedDict(
        conv=nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False),
        bn=nn.BatchNorm2d(num_features=2)
    ))

    module.eval()

    with torch.no_grad():
        output1 = module(f1)
        print(output1)

    # fuse conv + bn
    kernel = module.conv.weight 
    running_mean = module.bn.running_mean
    running_var = module.bn.running_var
    gamma = module.bn.weight
    beta = module.bn.bias
    eps = module.bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)  # [ch] -> [ch, 1, 1, 1]
    kernel = kernel * t
    bias = beta - running_mean * gamma / std
    fused_conv = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True)
    fused_conv.load_state_dict(OrderedDict(weight=kernel, bias=bias))

    with torch.no_grad():
        output2 = fused_conv(f1)
        print(output2)

    np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
    print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()