复旦大学提出基于Mamba的轨迹预测模型DeMo: 将运动预测解耦为方向意图和动态状态
Abstract
准确的交通主体运动预测对于在动态变化环境中确保自动驾驶系统的安全性和效率至关重要。主流方法采用“一查询一轨迹”的范式,其中每个查询对应一个唯一的轨迹,用于预测多模态轨迹。尽管这种方法简单且有效,但由于缺乏对未来轨迹的详细表示,可能会导致次优结果,因为主体状态会随时间动态演变。为了解决这个问题,我们提出了DeMo框架,它将多模态轨迹查询解耦为两种类型:模式查询(mode queries),用于捕捉不同的方向意图;以及状态查询(state queries),用于跟踪主体随时间变化的动态状态。通过利用这种格式,我们分别优化了轨迹的多模态性和动态演化特性。随后,模式查询和状态查询结合起来,以获得对轨迹的全面而详细的表示。为实现这些操作,我们还引入了Attention(注意力机制)和Mamba技术的结合,用于全局信息聚合和状态序列建模,充分利用各自的优势。在Argoverse 2和nuScenes基准数据集上的大量实验表明,我们的DeMo在运动预测中达到了最先进的性能。
欢迎加入自动驾驶实战群
Introduction
运动预测能够使自动驾驶车辆预测周围主体的运动,并影响自车的行为,为自车的行动提供参考和条件。这对于维护安全性和可靠性至关重要,使得车辆能够理解驾驶环境的动态变化并做出经过计算的决策。该任务的挑战和复杂性来自于多种因素,包括不可预测的路况、交通参与者的不同运动模式,以及同时分析被观测主体状态和道路地图的必要性。
研究社区在驾驶场景表示和轨迹解码的范式上取得了显著进展。这些方法主要遵循了一种从检测中借鉴的模式,即“一查询一轨迹”的范式。该范式使用多个查询来表示不同的预测轨迹,覆盖了不同运动意图的可能性。尽管有效,但这些方法只能大致提供一个方向,并通过一次性方式收集周围环境来生成各种轨迹的路径点,忽视了轨迹与场景的详细关系。这种缺乏对轨迹的具体表示以及与周围环境和其他主体的时空交互,可能导致在不同时间步的准确性和一致性下降。
为了解决这个问题,我们提出了一个名为DeMo的新框架,提供了多模态轨迹的详细表示。具体而言,我们将预测查询解耦为两种类型:除了原来的运动模式查询(mode queries),用于捕捉不同的方向意图外,我们还引入了动态状态查询(state queries),用于跟踪未来轨迹中主体在不同时间步的动态状态。这种方法使我们能够在框架中实现全面的查询表示。模式查询和状态查询分别通过模式定位模块(Mode Localization Module)和状态一致性模块(State Consistency Module)进行处理。这些模块使查询能够与周围环境及彼此之间进行明确的交互,从而显著优化未来轨迹的方向精度和时间一致性。随后,通过我们的混合耦合模块(Hybrid Coupling Module),将两种类型的查询整合在一起,实现对未来轨迹的全面表示。由于轨迹状态具有顺序特性,Mamba特别适合用于建模动态状态的时间一致性。因此,我们在模块中结合了Attention和Mamba,以有效聚合全局信息并建模状态序列,充分利用这两种技术的优势。
3.Method
在本节中,我们介绍了DeMo框架,该框架利用解耦的模式查询和状态查询来预测未来轨迹中的方向意图和动态状态。我们还采用了结合Attention和Mamba的混合架构,并引入了两个辅助损失进行特征建模。
3.1 问题定义
给定高清地图(HD map)和驾驶场景中的主体,运动预测旨在为感兴趣的主体预测未来轨迹。高清地图由多个车道或交叉路口的折线组成,而主体是交通参与者,如车辆和行人。为了将这些元素转化为易于处理和学习的输入,我们采用了一种流行的向量化表示,参考文献。具体而言,地图是通过将每条线段分割成若干较短的段生成的,其中、和分别表示地图折线的数量、分段数和特征通道。我们将主体的历史信息表示为
,其中
和分别是主体数量、历史时间戳和运动状态(如位置、航向角、速度)。此外,感兴趣主体的未来轨迹
是估计的目标,
、分别表示选择的主体数量和未来时间戳。
3.2 场景上下文编码
给定主体的向量化表示A和高清地图M,我们首先分别使用个体编码器对它们进行处理。具体来说,我们使用基于PointNet的折线编码器,对地图表示MMM进行处理,生成地图特征。对于主体A,我们用多个单向Mamba块替换了通用的Transformer或RNN,这些块在序列编码方面更高效,以便在当前时间之前聚合历史轨迹特征
。随后,将它们拼接在一起形成场景上下文特征
,并进一步传递给Transformer编码器,以进行内部交互学习。整体过程可表述为:
3.3 使用解耦查询的轨迹解码
在获得场景上下文特征后,我们的目标是基于提出的解耦查询为每个感兴趣的主体解码多模态未来轨迹。如图2所示,解码网络由状态一致性模块(State Consistency Module)组成,该模块增强动态未来状态查询的一致性和准确性;模式定位模块(Mode Localization Module)学习不同的运动模式;混合耦合模块(Hybrid Coupling Module)集成解耦查询并生成最终输出。以下是这些组件的详细描述。
动态状态一致性
考虑到未来轨迹的递归性和因果关系,我们提出将其表示为跨时间步的动态状态序列,既独立又相互关联。为了保留精确的时间信息,状态查询通过MLP模块初始化,用于实时差异。值得注意的是,时间步可以与不同,以平衡效果和效率,特别是在预测长期未来轨迹或更高频率的未来轨迹时。然后,使用状态一致性模块增强状态查询的一致性并聚合特定的场景上下文,该过程可表述为:
具体来说,首先应用交叉注意力,使状态查询与场景上下文进行交互,随后使用Mamba块以线性时间复杂度建模序列关系。同时,为了考虑后方状态查询对前方状态查询的影响,我们采用双向Mamba进行前后扫描。此外,使用简单的MLP模块将状态查询解码为单个未来轨迹,以明确监督时间一致性。
方向意图定位
模式查询表示不同的运动模式,每个查询负责解码K条轨迹之一。我们使用模式定位模块(Mode Localization Module)定位潜在的方向意图,如下所示:
为了学习空间运动,使用了两个多头注意力(Multi-Head Attention)块来启用模式查询之间以及与场景上下文的交互。此外,我们还使用简单的MLP解码未来轨迹和概率。同样,我们引入了另一个辅助监督,以赋予模式查询不同的运动意图。
混合查询耦合
为了结合动态状态和方向意图,我们简单地将和相加,形成混合时空查询。然后,使用混合耦合模块进一步处理,以产生未来轨迹的全面表示,如下所示:
除了用于与场景上下文、模式和时间状态之间进行交互的Attention和Mamba模块外,我们还引入了一个混合自注意力层,该层连接时间和模式之间的查询,增强了预测轨迹的多样性。此模块中的特征维度变化如图2所示。最终的预测结果通过MLP解码生成轨迹位置和概率。
3.4 训练损失
DeMo通过三种组成损失以端到端的方式进行训练。主要采用回归损失和分类损失,用于监督预测轨迹的准确性及其相关的概率得分。此外,我们引入了两个辅助损失,分别是针对时间状态中间特征的和针对运动模式的。前者用于增强动态状态在各个时间步之间的一致性和因果关系,而后者赋予模式以明确的方向意图。总体损失LLL是这些各自损失的组合,并赋予了相等的权重,公式如下所示:
我们采用交叉熵损失来对概率得分进行分类,并使用Smooth-L1损失来处理轨迹回归任务。采用“胜者全得”策略,仅优化与真实值平均预测误差最小的最佳预测。
4.Experiment
4.1 与现有技术的比较
我们首先在Argoverse 2运动预测基准测试中比较了DeMo与多个模型在单一代理设置中的表现,如表1所示。为了确保全面和公平的比较,我们分别评估了不同方法在使用和不使用模型集成技术时的性能。结果表明,DeMo显著超越了包括最先进的QCNet及其后期优化增强版SmartRefine 在内的所有之前的方法。具体来说,我们的方法在所有指标上明显优于其他方法,尤其是在minFDE1和minADE1指标上,相对于QCNet分别表现出13.02%和11.83%的性能提升。在使用与其他参赛者类似的集成技术后,DeMo在所有指标上都以较大优势超越了所有方法。随后,我们将DeMo在nuScenes运动预测基准测试中的表现与其他方法进行比较,测试分割结果如表2所示。我们的方法在除minADE5以外的所有指标上也优于其他方法。
4.2 多代理定量结果
在多代理环境中,预测器需要同时预测所有相关代理的未来路径,以全面了解驾驶情况。为了验证我们模型DeMo的有效性,我们在Argoverse 2多代理数据集上进行了测试。结果如表3所示,尽管我们的模型缺少如模型中存在的专门多代理预测功能,但由于我们的新设计,DeMo在所有评估指标上都超越了最近的先进方法。
4.3消融研究
组件的效果
表4展示了我们方法中各组件的有效性。我们在第一行展示了基线方法,类似于之前的方法,该方法使用模式查询生成多模态的未来轨迹。然后,在第二行(ID-2)中,我们直接采用状态查询来解码轨迹,观察到性能下降,这主要是由于查询过多,增加了模型的负担,并且难以区分不同类型查询的含义。在第三行(ID-3)中,我们引入了两个辅助损失,与第一行相比,性能略有提升。尽管模型能够识别每个查询代表的含义,但由于信息有限,性能表现中等。在第四行(ID-4)中,我们加入了图2中的三个聚合模块,但移除了辅助损失,导致性能显著提升。最后,在第五行(ID-5)中,DeMo集成了所有这些技术,达到了出色的表现。
状态序列建模与Mamba的效果
Mamba在序列建模方面表现出色,因此我们使用双向Mamba 来增强不同时间步长间的状态一致性。为了展示其效果,我们将双向Mamba与其他几个模块进行比较,包括单向Mamba、Attention、Conv1d 和 GRU 。如表5左侧所示,由于双向Mamba专门用于序列建模并且能够进行前向和后向扫描,它的配置优于其他模块。
辅助损失和聚合模块的效果
我们进行了消融研究以评估辅助损失和聚合模块的影响。正如表5右侧所示,移除任何这些损失或模块都会导致模型性能下降。值得注意的是,聚合模块的影响比辅助损失更大。这是因为从场景上下文和查询之间学习信息在解耦查询中代表不同含义时至关重要。
状态查询的效果
我们对状态查询的数量进行了消融研究,如表6左侧所示。在默认设置中,我们使用60个状态查询来表示60个时间戳的未来状态。随着状态查询数量的逐渐减少,模型性能下降,因为状态查询的含义越来越模糊。
Attention和Mamba模块深度的效果
Attention和Mamba单元的适当深度配置对于在效率和性能之间实现最佳平衡至关重要。如表6右侧所示,我们对层深度进行了消融研究。结果显示,当Attention单元深度为三层,Mamba单元深度为两层时,性能最佳。
编码器中Mamba模块深度的效果
我们对DeMo编码器中用于编码代理历史信息的Mamba模块进行了消融研究。如表7所示,左侧展示了用于编码代理历史信息的不同模块。我们的目标是聚合历史信息至当前时间点,因此单向Mamba是最合适的选择。
4.4 改进查询解耦测量的分析
我们通过minADE和minFDE来测量状态查询和模式查询的输出,如表8所示。可以看出,来自状态查询输出的轨迹的minADE1和minFDE1优于来自模式查询输出的轨迹。这意味着状态查询中编码了状态动态。此外,模式查询有六条输出轨迹,表明模式查询中主要存储了方向信息。最终输出利用了两者的优势。
4.5 效率分析与定性结果
在模型部署中,平衡性能、推理速度和模型大小非常重要。我们将DeMo与两个最近的代表性模型进行了比较:最先进的QCNet及其通过后期优化增强的SmartRefine。我们的模型大小为5.9M,相比之下,QCNet为7.7M,SmartRefine为8.0M。尽管我们的模型更小,但在性能上表现出显著的优势,详细信息见表1。至于推理速度,我们比较了DeMo和QCNet,它们都是端到端方法。在Argoverse 2单代理验证集上使用NVIDIA GeForce RTX 3090 GPU进行测量,批量大小维持为1。DeMo的平均推理速度仅为38ms,比QCNet的94ms快约2.5倍。这表明我们的方法不仅优于QCNet,而且更高效。
在图3中,我们展示了网络的定性结果。面板(a)显示了缺乏解耦查询范式的基线模型的结果,而面板(b)则展示了DeMo的结果。从前两行可以看出,通过显式优化未来轨迹的动态状态,我们的模型预测的轨迹更准确,更接近地面真值。从第三行可以看出,我们的模型能够更好地捕捉潜在的方向意图。
4.6 与其他方法的计算成本比较
我们在表9中提供了与最近代表性方法的计算成本比较。实验是在Argoverse 2 数据集上使用8个NVIDIA GeForce RTX 3090 GPU进行的。
结论
本文的贡献总结如下:
- 本文提出了一个运动预测框架,将多模态轨迹查询解耦为模式查询和状态查询,分别用于表示方向意图和动态状态。
- 本文设计了基于Attention和Mamba的三个模块,用于处理解耦的模式查询、状态查询以及耦合的模式和状态查询。
- 在Argoverse 2和nuScenes基准数据集上的大量实验表明,DeMo在运动预测中实现了最先进的性能。
论文引用:
DeMo: Decoupling Motion Forecasting intoDirectional Intentions and Dynamic States
最后别忘了,帮忙点“在看”。
您的点赞,在看,是我创作的动力。
AiFighing是全网第一且唯一以代码、项目的形式讲解自动驾驶感知方向的关键技术。
长按扫描下面二维码,加入知识星球。