前情提要
本文是传知代码平台中的相关前沿知识与技术的分享~
接下来我们即将进入一个全新的空间,对技术有一个全新的视角~
本文所涉及所有资源均在传知代码平台可获取
以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!
以下内容干货满满,跟上步伐吧~
💡本章重点
- 当下最牛的图像压缩算法
🍞一. 概述
首先,这篇文章的出发点就是图像压缩最本源的目的,就是探索如何在相同的码率下获得更高质量的重建图像,或者说在得到的重建图像质量一样的情况下,如何进一步节省码率。
然后作者就站在前人做的利用深度学习压缩的基础上思考,有一批人使用CNN的方法,可以很好地降低空间冗余度,然后捕获图像的空域结构;另一批人使用Transformer的结构,来捕捉图像中长距离的空间依赖关系。于是作者就想,能不能把这两种方法做一个结合,做这么样一个结构,使其同时具备这两种算法的优点。于是就在此基础上,作者提出了本文的方法。
🍞二. 先验知识
在这一部分,我结合图文向大家解释一下基于深度学习进行图像压缩的基本框架流程,便于进一步理解本文方法。
先给出示意图如下:
首先是原图经过编码器得到一个潜在的表示y,就可以类比传统图像压缩里稀疏化的变换,只不过这里用一个可以学习的变换器来代替之前的人工设计的变换方法。
然后得到y之后,我们也是将y进行量化,熵编码,然后打包成码流进行传输。熵编码的部分呢,就是通过学习y的一些特征,来指导熵编码器对量化后的y进行更加精确的熵估计,一般是用基于高斯分布的算术编码器来进行熵编码,所以可以看到,这一部分学习的参数,往往也是均值和方差等等。
大家做的工作一般也是集中于如何改进这个编码器的结构,得到更加合理的潜在表示,然后一方面就是对熵编码器这里做一些工作,想方设法能使熵编码器对量化后的y进行更加准确地估计,从而做到节省码率。
🍞三.LIC-TCM算法
亮点一:TCM块
那么我们首先来看一下TCM块的设计。
在这个结构里大家可以看到作者是使用了这个Swin Transformer 块和残差块来实现的一个两个方法的融合。
具体过程:输入的特征向量经过一个11的卷积,我们知道11卷积能够很好的糅合各通道之间的信息,然后下一步就是在通道维度对这个特征向量做一个切割,分别送入到Transformer块和残差块里进行学习。采用这种并行式的处理,一方面可以减小参数量,另一方面,能够分别学习各自擅长学习的特征。然后对各自得到的结果向量,先进行一个Concatenate,然后同样经过一个1*1卷积,对其各自的特征进行一个交互。
需要注意的是作者并没有将这个向量直接作为输出,而是进行了一个双阶段的设计,我认为这个也是以Swin-Transformer为启发,可以更好地对特征进行融合。
然后整体的表达式就是下面这三个式子。
亮点二:熵模型设计
接下来是作者的第二部分工作,提出了一种熵模型。
可以看出常规的熵模型是将整个y送入到一个提取超先验信息的网络当中,然后得到熵编码需要的参数。而这里的熵模型可以理解为将y沿着通道维度拆分成为多段,分别进行熵编码,最后再进行拼接。这么做的好处,不仅可以利用GPU的并行处理,而且还能通过上一段解码出的y进行指导,获得更加准确的估计。这个熵模型的想法是前人的工作,作者主要做的是在这样一个个参数估计的网络里引入了一种注意力机制。
基于Swin-Transformer块的注意力机制。可以看到SWAttention模块就是下图这样的一个结构,与前面TCM block的设计思路类似,作者都是想将非局部信息和局部信息做一个很好的结合,于是就在获取注意力的时候加入这样一个Swin-Transformer的基本块,来映射一些非局部的信息。
🍞四.核心代码解读
这里主要介绍模型部分的代码,对于一些基本卷积操作或训练时的基础设置等不做赘述。
- 首先是TCM模块部分
def forward(self, x):
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
conv_x = self.conv_block(conv_x) + conv_x
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
trans_x = self.trans_block(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
x = x + res
return x
这是TCM模块的前向过程,也是本文的核心之一。可以看到与论文表述一致,通过对输入特征split操作之后,分别入到残差块和Swin-Transformer块里,再做拼接操作。
- Swin-Transformer块的代码如下:
def forward(self, x):
resize = False
if (x.size(-1) <= self.window_size) or (x.size(-2) <= self.window_size):
padding_row = (self.window_size - x.size(-2)) // 2
padding_col = (self.window_size - x.size(-1)) // 2
x = F.pad(x, (padding_col, padding_col+1, padding_row, padding_row+1))
trans_x = Rearrange('b c h w -> b h w c')(x)
trans_x = self.block_1(trans_x)
trans_x = self.block_2(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
if resize:
x = F.pad(x, (-padding_col, -padding_col-1, -padding_row, -padding_row-1))
return trans_x
其中的Block块为内嵌的Transformer块:
def forward(self, x):
x = x + self.drop_path(self.msa(self.ln1(x)))
x = x + self.drop_path(self.mlp(self.ln2(x)))
return x
- 最后整个模型的结构如下:
其中,熵模型(即右侧红框部分)采用基于通道划分和窗注意力机制,其训练过程中的前向代码如下,估计出高斯建模的均值和方差。
for slice_index, y_slice in enumerate(y_slices):
support_slices = (y_hat_slices if self.max_support_slices < 0 else y_hat_slices[:self.max_support_slices])
mean_support = torch.cat([latent_means] + support_slices, dim=1)
mean_support = self.atten_mean[slice_index](mean_support)
mu = self.cc_mean_transforms[slice_index](mean_support)
mu = mu[:, :, :y_shape[0], :y_shape[1]]
mu_list.append(mu)
scale_support = torch.cat([latent_scales] + support_slices, dim=1)
scale_support = self.atten_scale[slice_index](scale_support)
scale = self.cc_scale_transforms[slice_index](scale_support)
scale = scale[:, :, :y_shape[0], :y_shape[1]]
scale_list.append(scale)
_, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)
y_likelihood.append(y_slice_likelihood)
y_hat_slice = ste_round(y_slice - mu) + mu
# if self.training:
# lrp_support = torch.cat([mean_support + torch.randn(mean_support.size()).cuda().mul(scale_support), y_hat_slice], dim=1)
# else:
lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
lrp = self.lrp_transforms[slice_index](lrp_support)
lrp = 0.5 * torch.tanh(lrp)
y_hat_slice += lrp
y_hat_slices.append(y_hat_slice)
- 编码器部分则采用算术编码器,具体调用代码如下:
encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)
y_string = encoder.flush()
以上内容即为LIC-TCM模型的核心代码讲解。
🫓总结
综上,我们基本了解了“一项全新的技术啦” :lollipop: ~~
恭喜你的内功又双叒叕得到了提高!!!
感谢你们的阅读:satisfied:
后续还会继续更新:heartbeat:,欢迎持续关注:pushpin:哟~
:dizzy:如果有错误❌,欢迎指正呀:dizzy:
:sparkles:如果觉得收获满满,可以点点赞👍支持一下哟~:sparkles:
【传知科技 -- 了解更多新知识】