论文标题 FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation code link
FRESCO方法是不需要对模型进行重新训练,通过修改attention模块以及利用光流方法从而保证视频的稳定性,这一点实际上与ControlVedio模型是结构非常详细的。这里面有很多值得借鉴的内容,已经有人对论文的整体内容做过解读link,本文主要来分析他们是如何实现对应的模块的。
程序结构
FRESCO程序是在
controlnet模型的基础上进行修改的,我们可以看到模型上面有feature optimization以及FreSCo-guided attention模块。其中feature optimization是用来增强视频的时间和空间上的稳定性,而FreSCo-guided attention模块用不同的attention模块将信息组合起来。
feature optimization
论文作者将优化目标定义为 ,我们的优化目标有空间部分,以及时间部分,对应的公式分别为:
我们首先来看一下空间部分的程序是怎样实现的:
空间一致性
作者希望通过比较原始的feature f 添加了noise之后的feature 去进行比较,来提高时空一致性。因此需要去分别计算f 以及 地余弦相似度。我们需要首先进行归一化处理得到 ,随后计算gram矩阵,也就是与余弦相似度。
# spatial consistency loss
if attention_probs is not None and intra_weight > 0:
cs_vector = rearrange(cs, "b f c h w -> (b f) (h w) c")
#attention_scores = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))
#cs_attention_probs = attention_scores.softmax(dim=-1)
cs_vector = cs_vector / ((cs_vector ** 2).sum(dim=2, keepdims=True) ** 0.5)
cs_attention_probs = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))
tmp = F.l1_loss(cs_attention_probs, attention_probs) * intra_weight
loss = tmp + loss
首先是对数据进行维度变换为,即将每个batch和帧中的特征拉平为一个向量,为了方便后续去计算相似矩阵。
cs_vector = rearrange(cs, "b f c h w -> (b f) (h w) c")
随后我们对矩阵cs_vector去进行归一化处理,并使用torch.bmm计算得到相似性矩阵cs_attention_probs。
cs_vector = cs_vector / ((cs_vector ** 2).sum(dim=2, keepdims=True) ** 0.5)
cs_attention_probs = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))
最后使用对相似性矩阵以及关联矩阵计算,这里面的关联矩阵attention_probs
tmp = F.l1_loss(cs_attention_probs, attention_probs) * intra_weight
loss = tmp + loss
随后是先将其添加noise后的 的计算。 我们首先通过pipe.prepare_latents初始化一个随机的latents,随后通过对原始图像使用vae编码得到了latent_x0,此时latents_x0是无噪声的,我们通过noise_scheduler向latent_x0中添加噪音得到了带噪声的结果。
latents = pipe.prepare_latents(...)
latent_x0 = pipe.vae.config.scaling_factor *pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
latents = noise_scheduler.add_noise(latent_x0, latents, timestep).detach()
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
我们将带噪声的latent_model_input输入到Unet中最终得到了model_output,model_output中第一个元素为denoised_image我们并不需要这个结果,所以后续会去掉。
model_output = pipe.unet(latent_model_input, timestep, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=None, return_dict=False)
我们对model_output出去第一个元素,来处理剩余的结果。首先是对数据转化为 的形式,这与上文所提到的差出了一个f。 随后是归一化,使用torch.bmm去计算gram矩阵,并将结果保存在correlation_matrix,correlation_matrix在去做归一化之后得到了上文中提到的attention_probs
correlation_matrix = []
for tmp in model_output[1:]:
latent_vector = rearrange(tmp, "b c h w -> b (h w) c")
latent_vector = latent_vector / ((latent_vector ** 2).sum(dim=2, keepdims=True) ** 0.5)
attention_probs = torch.bmm(latent_vector, latent_vector.transpose(-1, -2))
correlation_matrix += [attention_probs.detach().clone().to(torch.float32)]
del attention_probs, latent_vector, tmp
del model_output
时间一致性
我们希望两个相邻帧之间应该物体是一致的,作者使用了光流场方法,这里面分别用到了bwd_flow_,fwd_flow_,bwd_occ_,fwd_occ_分别对应反向光流场,反向遮挡区域,正向光流场,正向遮挡区域,这里面包括了相邻帧之间的运动信息(光流场)和遮挡信息,并且已经通过插值方法匹配输入数据的分辨率和格式
# unify resolution
if flows is not None and occs is not None:
scale = sample.shape[2] * 1.0 / flows[0].shape[2]
kernel = int(1 / scale)
bwd_flow_ = F.interpolate(flows[1] * scale, scale_factor=scale, mode='bilinear').repeat(unet_chunk_size,1,1,1)
bwd_occ_ = F.max_pool2d(occs[1].unsqueeze(1), kernel_size=kernel).repeat(unet_chunk_size,1,1,1) # 2(N-1)*1*H1*W1
fwd_flow_ = F.interpolate(flows[0] * scale, scale_factor=scale, mode='bilinear').repeat(unet_chunk_size,1,1,1)
fwd_occ_ = F.max_pool2d(occs[0].unsqueeze(1), kernel_size=kernel).repeat(unet_chunk_size,1,1,1) # 2(N-1)*1*H1*W1
# match frame 0,1,2,3 and frame 1,2,3,0
reshuffle_list = list(range(1,video_length))+[0]
完成光流场的计算之后,我们有c1,c2分别对应到顺序帧和下一帧,例如:[0 1 2 3] ~ [1 2 3 0],由此得到当前帧c1[i]与下一帧c2[i]。
随后使用gmflow中的flow_warp根据当前帧和光流场去计算扭曲后的图像,并最后将c1和c2和遮挡信息结合得到最终的loss。
if optimize_temporal and flows is not None and occs is not None:
c1 = rearrange(cs[:,:], "b f c h w -> (b f) c h w")
c2 = rearrange(cs[:,reshuffle_list], "b f c h w -> (b f) c h w")
warped_image1 = flow_warp(c1, bwd_flow_)
warped_image2 = flow_warp(c2, fwd_flow_)
loss = (abs((c2-warped_image1)*(1-bwd_occ_)) + abs((c1-warped_image2)*(1-fwd_occ_))).mean() * 2
最后我们将时间和空间部分的loss加到一起,调用optimizer来优化原始的输入cs。
FreSCo-guided attention
这里的attention模块有点奇怪,后续再去写