MagicDrive可以将文字描述等信息通过Diffusion生成模型转化为自动驾驶中所需要的BEV视图,而且其是目前是效果比较好的模型中唯一开源的,并且注释非常详细,因此我们来学习一下这里面的内容。
信息来源: code home_page 作者解读
MagicDrive主要框架
MagicDrive本质上是基于Bevgen模型的思路,通过将更加适合自动驾驶的Road Map, Object Box, Prompt等内容作为模型的输入,在nuScenes数据集上去做训练。 因此这里面主要包括两部分的内容:模型输入的编码(Road Map之类) 以及对模型的修改从而满足:多视角一致性和时空的一致性(后面会具体是介绍)。接下来我们看一下对应的模型是怎样工作和实现的
程序结构
我们首先来看 pipeline_bev_controlnet.py这个文件,作者在这里完成的注释也是非常详细的,我们一段一段地分析一下具体的内容。
首先是prompt对应到batch_size,例如我们写出两个prompt : {rainy, city} {sunny, rural road}就对应到两个不同的场景,对应的batch_size = 2,将其放到同一批次可以重复利用编码好的Road Map等内容,而不用重复计算,。
# 2. Define call parameters
# NOTE: we get batch_size first from prompt, then align with it.
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
随后是对判定是否做cfg操作,对cfg的解释可以参考link。在这里我们对N_cam进行硬编码为6,也就是对应到BEV视角中的六个不同的视角。随后是对相加参数camera_param进行编码,其目的是为了生成生成无条件的相机参数(具体对相机的编码可以参考后面的内容)。
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
### BEV, check camera_param ###
if camera_param is None:
# use uncond_cam and disable classifier free guidance
N_cam = 6 # TODO: hard-coded
camera_param = self.controlnet.uncond_cam_param((batch_size, N_cam))
do_classifier_free_guidance = False
### done ###
这里我们对prompt进行编码,这里使用的_encode_prompt是在diffusers里面的官方库实现的,使用的编码器为CLIP
# 3. Encode input prompt
# NOTE: here they use padding to 77, is this necessary?
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) # (2 * b, 77 + 1, 768)
这里是对控制图像进行转换,保证其都为统一的格式。 这一步对应到的是Road Map的转换
# 4. Prepare image
# NOTE: if image is not tensor, there will be several process.
assert not self.control_image_processor.config.do_normalize, "Your controlnet should not normalize the control image."
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
) # (2 * b, c_26, 200, 200)
if use_zero_map_as_unconditional and do_classifier_free_guidance:
# uncond in the front, cond in the tail
_images = list(torch.chunk(image, 2))
_images[0] = torch.zeros_like(_images[0])
image = torch.cat(_images)
随后是配置对应的timesteps,以及去生成所需要去噪的初始latents。
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents, # will use if not None, otherwise will generate
) # (b, c, h/8, w/8) -> (bs, 4, 28, 50)
接下来我们对模型的输入去进行编码,我们在这一步需要去看一看prompt, camrea, object box以及Road Map是如何编码进去的。
- 首先对于文字prompt已经是在先前被编码为
prompt_embeds. - 随后对于相机信息,在之前被编码为
camera_param, - 对于Road Map部分,这里对应的
image。 - 针对
object box进行编码,主要是去使用add_uncond_to_kwargs这个函数将object box信息放到了bev_controlnet_kwargs中,具体的编码方法add_uncond_to_kwargs我们会在后面详细介绍。
# 7. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
###### BEV: here we reconstruct each input format ######
assert camera_param.shape[0] == batch_size, \
f"Except {batch_size} camera params, but you have bs={len(camera_param)}"
N_cam = camera_param.shape[1]
latents = torch.stack([latents] * N_cam, dim=1) # bs, 6, 4, 28, 50
# prompt_embeds, no need for b, len, 768
# image, no need for b, c, 200, 200
camera_param = camera_param.to(self.device)
if do_classifier_free_guidance and not guess_mode:
# uncond in the front, cond in the tail
_images = list(torch.chunk(image, 2))
kwargs_with_uncond = self.controlnet.add_uncond_to_kwargs(
camera_param=camera_param,
image=_images[0], # 0 is for unconditional
max_len=bbox_max_length,
**bev_controlnet_kwargs,
)
kwargs_with_uncond.pop("max_len", None) # some do not take this.
camera_param = kwargs_with_uncond.pop("camera_param")
_images[0] = kwargs_with_uncond.pop("image")
image = torch.cat(_images)
bev_controlnet_kwargs = move_to(kwargs_with_uncond, self.device)
###### BEV end ######
对于去噪过程,对应的代码比较长,我们来看一下核心的内容:首先是使用controlnet将所有的信息都去进行编码为encoder_hidden_states_with_cam,随后将信息输入到Unet中去得到预测的噪音,最后是对letents进行修改。
encoder_hidden_states_with_cam = self.(
controlnet_latent_model_input,
controlnet_t,
camera_param, # for BEV
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
guess_mode=guess_mode,
return_dict=False,
**bev_controlnet_kwargs, # for BEV
)
for i, t in enumerate(timesteps):
noise_pred = self.unet(
latent_model_input, # may with unconditional
t,
encoder_hidden_states=encoder_hidden_states_with_cam,
**additional_param, # if use original unet, it cannot take kwargs
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
).sample
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
MagicDrive注意力机制模块
在了解了整体的模型框架后,我们来看一下他们是怎样去保证一致性,这里的一致性包括多视角的一致性(BEV的六个视角的一致性),以及视频生成中的帧与帧之间的一致性(也就是对应的时空一致性),不过这里MagicDrive模型的时间一致性比较弱,视频中的物体变换还是非常明显,我们可以参考其他的论文中的方法对其进行改进。
多视角一致性
对于多视角一致性,我们很直观地可以想到将BEV的六个视角的图像使用cross-attention关联起来,这里MagicDrive做了实现,发现在生成图像时将左右两边的图像关联进来的效果最好。下面是对应的公示内容:
下面是对应的程序实现,可以观察到有两种处理
neighbor的方式,分别是:add以及contact,作者测试了两种效果,一般而言是add的效果会更好一些。 这里面还有一个neighboring_view_pair,在代码段最下面,就是将每个位置的左右邻居编码了进去。
随后我们看到了两个循环从neighboring_view_pair中实现了将相邻的hidden_states使用"add"方法进行融合。例如: key = 1, values = [0,2],hidden_states_in1就会存入两份norm_hidden_states[:,1],同样地hidden_states_in2会存入norm_hidden_states[:,0]以及norm_hidden_states[:,2]。从而实现了左右邻居信息的融合
def _construct_attn_input(self, norm_hidden_states):
B = len(norm_hidden_states)
# reshape, key for origin view, value for ref view
hidden_states_in1 = []
hidden_states_in2 = []
cam_order = []
if self.neighboring_attn_type == "add":
for key, values in self.neighboring_view_pair.items():
for value in values:
hidden_states_in1.append(norm_hidden_states[:, key])
hidden_states_in2.append(norm_hidden_states[:, value])
cam_order += [key] * B
# N*2*B, H*W, head*dim
hidden_states_in1 = torch.cat(hidden_states_in1, dim=0)
hidden_states_in2 = torch.cat(hidden_states_in2, dim=0)
cam_order = torch.LongTensor(cam_order)
elif self.neighboring_attn_type == "concat":
for key, values in self.neighboring_view_pair.items():
hidden_states_in1.append(norm_hidden_states[:, key])
hidden_states_in2.append(torch.cat([
norm_hidden_states[:, value] for value in values
], dim=1))
cam_order += [key] * B
# N*B, H*W, head*dim
hidden_states_in1 = torch.cat(hidden_states_in1, dim=0)
# N*B, 2*H*W, head*dim
hidden_states_in2 = torch.cat(hidden_states_in2, dim=0)
cam_order = torch.LongTensor(cam_order)
neighboring_view_pair:
0: [5, 1]
1: [0, 2]
2: [1, 3]
3: [2, 4]
4: [3, 5]
5: [4, 0]
在得到了hidden_states_in1,hidden_states_in2,我们将其输入到一个attention模块中,完成了从而完成了这部分的计算。self.attn4为普通的Attention模块
attn_raw_output = self.attn4(
hidden_states_in1,
encoder_hidden_states=hidden_states_in2,
**cross_attention_kwargs,
)
时空一致性
MagicDrive只是简单地尝试了一下,效果并不是很好,视频画面不够稳定。
首先是看下面的原理图,可以看到为为了生成一个视角的图像,我们不仅要有左右视角的信息,还要有key frame以及prev frame这两部分的信息。这一部分是采用temp attn的方式来实现的,接下来我们看一下具体的程序。
这里的方法与之前非常一致,就是提取对应的第一帧作为关键帧,将第一帧的key value提取出来替换掉后续的图像的key value,这里面自定义了
rearrange_3函数来转换对应的shape。
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# Sparse Attention
if not is_cross_attention:
video_length = key.size()[0] // self.batch_size
first_frame_index = [0] * video_length
# rearrange keys to have batch and frames in the 1st and 2nd dims respectively
key = rearrange_3(key, video_length)
key = key[:, first_frame_index]
# rearrange values to have batch and frames in the 1st and 2nd dims respectively
value = rearrange_3(value, video_length)
value = value[:, first_frame_index]
# rearrange back to original shape
key = rearrange_4(key)
value = rearrange_4(value)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
def rearrange_3(tensor, f):
F, D, C = tensor.size()
return torch.reshape(tensor, (F // f, f, D, C))
改进内容
有好几篇的改进内容,对应的效果好了很多,不过目前都还没有开源,可以去等他们最终的效果。
编码函数
介绍对于object box的编码,camera的编码。 等后续再去完成