ControlVedio作为经典的视频转换模型,基于Controlnet模型,仅通过修改Unet的结构就可以得到比较流畅的视频生成效果。当然对应的视频生成效果已经被SORA暴打,不过还是非常有借鉴价值的,我们接下来看一下这里面的代码是怎样完成的。
模型的细节介绍可以参考 link
模型介绍
首先是对应的整体模型框架,这里面作者提出了将二维的注意力机制转化为fully Cross-Frame Attention,相当于将2D的注意力机制转化为3D的注意力机制效果。
随后是一个平滑的技术,我们暂时忽略这个内容。
最后介绍了如何提取关键帧以及应用关键帧生成更长的视频。通过先将我们所需要的视频的而关键帧生成出来,随后以关键帧作为参考来生成关键帧中间的插帧视频。
程序结构
我们首先对其中的attention模块进行分析。FullyFrameAttention以及BasicTransformerBlock为主要的修改内容。我们来看一下作者是怎样实现对应的程序的。我们diffusers源代码中的内容进行对比,来更好地观察修改的代码位置。
class BasicTransformerBlock(nn.Module):
def __init__(...):
# Fully
self.attn1 = FullyFrameAttention(...)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(...)
else:
self.attn2 = None
在初始化的过程中,和diffusers代码相比较,self.attn1模块由Attention变化为了FullyFrameAttention模块,也是对应到了上图中的右侧attention部分.我们来看一下其具体的结构是怎样的。 我们主要看其中的forward函数是怎样定义的
def forward(...):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) # (bf) x d(hw) x c
dim = query.shape[-1]
# All frames
query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
if inter_frame:
key = rearrange(key, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
value = rearrange(value, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
key = rearrange(key, "b f d c -> b (f d) c",)
value = rearrange(value, "b f d c -> b (f d) c")
else:
# All frames
key = rearrange(key, "(b f) d c -> b (f d) c", f=video_length)
value = rearrange(value, "(b f) d c -> b (f d) c", f=video_length)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
# All frames
hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length)
return hidden_states
在上面的程序中,我们可以看到最终的Q K V都被转化为了 b (f d) c的形状,其中b batch, f frame(就是输入的帧数), d=hw(高 * 宽), c 通道数。与普通的controlnt不同,hidden_states是f张图片的hidden_states,而并不只是单一图片的hidden_states。
query = self.to_q(hidden_states) # (bf) x d(hw) x c
dim = query.shape[-1]
# All frames
query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
query = self.reshape_heads_to_batch_dim(query)
随后是对于key和value的值,对于 ALL frames的情况,与上文是完全相同的
我们来看一下对于inter_frame的情况。这种情况对应到长视频生成中,我们是使用首尾两帧作为key和value值,来将中间的帧进行渲染。对应的公式为:
在程序中,我们首先是将
key和value完全展开,并且提取首尾两帧[:, [0, -1]],随后将结果重新整理为b (f d) c格式,当然此时f = 2(首尾两帧)。这对应到前文所提到的长视频生成部分,我们通过先使用ALL frames部分生成key-frame之后,去使用iter_frame来将中间部分的帧去重新生成,这样可以降低中间的GPU显存的需求,从而间接地生成了长视频
if inter_frame:
key = rearrange(key, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
value = rearrange(value, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
key = rearrange(key, "b f d c -> b (f d) c",)
value = rearrange(value, "b f d c -> b (f d) c")
else:
# All frames
key = rearrange(key, "(b f) d c -> b (f d) c", f=video_length)
value = rearrange(value, "(b f) d c -> b (f d) c", f=video_length)
剩下的部分就是和controlnet源代码一致,添加线性层等。