FRESCO程序解读

527 阅读5分钟

论文标题 FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation code link

FRESCO方法是不需要对模型进行重新训练,通过修改attention模块以及利用光流方法从而保证视频的稳定性,这一点实际上与ControlVedio模型是结构非常详细的。这里面有很多值得借鉴的内容,已经有人对论文的整体内容做过解读link,本文主要来分析他们是如何实现对应的模块的。

程序结构

image.png FRESCO程序是在controlnet模型的基础上进行修改的,我们可以看到模型上面有feature optimization以及FreSCo-guided attention模块。其中feature optimization是用来增强视频的时间和空间上的稳定性,而FreSCo-guided attention模块用不同的attention模块将信息组合起来。

feature optimization

论文作者将优化目标定义为 f^ =argmin(Lspat (f)+Ltemp (f))\hat{f}\ = argmin(\mathcal{L}_{\text {spat }}(\mathbf{f}) +\mathcal{L}_{\text {temp }}(\mathbf{f}) ),我们的优化目标有空间部分,以及时间部分,对应的公式分别为:
Lspat (f)=λspat if~if~if~irf~ir22\mathcal{L}_{\text {spat }}(\mathbf{f})=\lambda_{\text {spat }} \sum_{i}\left\|\tilde{f}_{i} \tilde{f}_{i}^{\top}-\tilde{f}_{i}^{r} \tilde{f}_{i}^{r \top}\right\|_{2}^{2}

Ltemp (f)=iMii+1(fi+1wii+1(fi))1\mathcal{L}_{\text {temp }}(\mathbf{f})=\sum_{i}\left\|M_{i}^{i+1}\left(f_{i+1}-w_{i}^{i+1}\left(f_{i}\right)\right)\right\|_{1}

我们首先来看一下空间部分的程序是怎样实现的:

空间一致性

作者希望通过比较原始的feature f 添加了noise之后的feature frf^{r}去进行比较,来提高时空一致性。因此需要去分别计算f 以及 frf^{r}地余弦相似度。我们需要首先进行归一化处理得到 f~i\tilde{f}_{i},随后计算gram矩阵,也就是与余弦相似度f~if~i\tilde{f}_{i} \tilde{f}_{i}^{\top}

# spatial consistency loss
if attention_probs is not None and intra_weight > 0:
   cs_vector = rearrange(cs, "b f c h w -> (b f) (h w) c")
   #attention_scores = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))
   #cs_attention_probs = attention_scores.softmax(dim=-1)
   cs_vector = cs_vector / ((cs_vector ** 2).sum(dim=2, keepdims=True) ** 0.5)
   cs_attention_probs = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))
   tmp = F.l1_loss(cs_attention_probs, attention_probs) * intra_weight
   loss = tmp + loss

首先是对数据进行维度变换为(bf)hwc(bf)(hw)c,即将每个batch和帧中的特征拉平为一个向量,为了方便后续去计算相似矩阵。

cs_vector = rearrange(cs, "b f c h w -> (b f) (h w) c")

随后我们对矩阵cs_vector去进行归一化处理,并使用torch.bmm计算得到相似性矩阵cs_attention_probs

cs_vector = cs_vector / ((cs_vector ** 2).sum(dim=2, keepdims=True) ** 0.5)
cs_attention_probs = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))

最后使用对相似性矩阵以及关联矩阵计算L1lossL_{1} loss,这里面的关联矩阵attention_probs

tmp = F.l1_loss(cs_attention_probs, attention_probs) * intra_weight
loss = tmp + loss

随后是先将其添加noise后的 frf^{r}的计算。 我们首先通过pipe.prepare_latents初始化一个随机的latents,随后通过对原始图像使用vae编码得到了latent_x0,此时latents_x0是无噪声的,我们通过noise_schedulerlatent_x0中添加噪音得到了带噪声的结果。

latents = pipe.prepare_latents(...)
latent_x0 = pipe.vae.config.scaling_factor *pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
latents = noise_scheduler.add_noise(latent_x0, latents, timestep).detach()
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 

我们将带噪声的latent_model_input输入到Unet中最终得到了model_outputmodel_output中第一个元素为denoised_image我们并不需要这个结果,所以后续会去掉。

model_output = pipe.unet(latent_model_input, timestep, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=None, return_dict=False)

我们对model_output出去第一个元素,来处理剩余的结果。首先是对数据转化为 b(hw)cb (h w) c的形式,这与上文所提到的(bf)hwc(bf)(hw)c差出了一个f。 随后是归一化,使用torch.bmm去计算gram矩阵,并将结果保存在correlation_matrixcorrelation_matrix在去做归一化之后得到了上文中提到的attention_probs

correlation_matrix = []
for tmp in model_output[1:]:
    latent_vector = rearrange(tmp, "b c h w -> b (h w) c")
    latent_vector = latent_vector / ((latent_vector ** 2).sum(dim=2, keepdims=True) ** 0.5)
    attention_probs = torch.bmm(latent_vector, latent_vector.transpose(-1, -2))
    correlation_matrix += [attention_probs.detach().clone().to(torch.float32)]
    del attention_probs, latent_vector, tmp
del model_output   

时间一致性

我们希望两个相邻帧之间应该物体是一致的,作者使用了光流场方法,这里面分别用到了bwd_flow_,fwd_flow_,bwd_occ_,fwd_occ_分别对应反向光流场,反向遮挡区域,正向光流场,正向遮挡区域,这里面包括了相邻帧之间的运动信息(光流场)和遮挡信息,并且已经通过插值方法匹配输入数据的分辨率和格式

# unify resolution
if flows is not None and occs is not None:
    scale = sample.shape[2] * 1.0 / flows[0].shape[2]
    kernel = int(1 / scale)
    bwd_flow_ = F.interpolate(flows[1] * scale, scale_factor=scale, mode='bilinear').repeat(unet_chunk_size,1,1,1)
    bwd_occ_ = F.max_pool2d(occs[1].unsqueeze(1), kernel_size=kernel).repeat(unet_chunk_size,1,1,1) # 2(N-1)*1*H1*W1
    fwd_flow_ = F.interpolate(flows[0] * scale, scale_factor=scale, mode='bilinear').repeat(unet_chunk_size,1,1,1)
    fwd_occ_ = F.max_pool2d(occs[0].unsqueeze(1), kernel_size=kernel).repeat(unet_chunk_size,1,1,1) # 2(N-1)*1*H1*W1
    # match frame 0,1,2,3 and frame 1,2,3,0
reshuffle_list = list(range(1,video_length))+[0]

完成光流场的计算之后,我们有c1c2分别对应到顺序帧和下一帧,例如:[0 1 2 3] ~ [1 2 3 0],由此得到当前帧c1[i]与下一帧c2[i]。
随后使用gmflow中的flow_warp根据当前帧和光流场去计算扭曲后的图像,并最后将c1c2和遮挡信息结合得到最终的loss。

if optimize_temporal and flows is not None and occs is not None:
    c1 = rearrange(cs[:,:], "b f c h w -> (b f) c h w")
    c2 = rearrange(cs[:,reshuffle_list], "b f c h w -> (b f) c h w")
    warped_image1 = flow_warp(c1, bwd_flow_)
    warped_image2 = flow_warp(c2, fwd_flow_)
    loss = (abs((c2-warped_image1)*(1-bwd_occ_)) + abs((c1-warped_image2)*(1-fwd_occ_))).mean() * 2

最后我们将时间和空间部分的loss加到一起,调用optimizer来优化原始的输入cs

FreSCo-guided attention

这里的attention模块有点奇怪,后续再去写