
HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction
准确预测周围交通智能体的未来运动对于自动驾驶车辆的安全至关重要。最近,矢量化方法因其能够捕捉交通场景中的复杂交互而成为运动预测领域的主流。然而,现有方法忽略了问题的对称性,并且面临昂贵的计算成本,挑战在于在不牺牲预测性能的情况下进行实时多智能体运动预测。为了应对这一挑战,我们提出了层次化矢量变换器(HiVT),用于快速准确的多智能体运动预测。通过将问题分解为局部上下文提取和全局交互建模,我们的方法
摘要
准确预测周围交通智能体的未来运动对于自动驾驶车辆的安全至关重要。最近,矢量化方法因其能够捕捉交通场景中的复杂交互而成为运动预测领域的主流。然而,现有方法忽略了问题的对称性,并且面临昂贵的计算成本,挑战在于在不牺牲预测性能的情况下进行实时多智能体运动预测。为了应对这一挑战,我们提出了层次化矢量变换器(HiVT),用于快速准确的多智能体运动预测。通过将问题分解为局部上下文提取和全局交互建模,我们的方法可以有效地对场景中的大量智能体进行建模。同时,我们提出了一种平移不变场景表示和旋转不变空间学习模块,这些模块提取对场景几何变换具有鲁棒性的特征,并使模型能够在单次前向传播中为多个智能体做出准确预测。实验表明,HiVT在Argoverse运动预测基准测试中实现了最先进的性能,并且模型规模小,能够进行快速的多智能体运动预测。
1.引言
自主车辆在动态环境中安全导航是一项关键任务。为此,自主车辆需要理解周围环境并预测未来的道路情况。然而,准确预测附近交通智能体(如车辆、自行车和行人)的未来运动是具有挑战性的,他们的目标或意图可能是未知的。在多智能体交通场景中,一个智能体的行为受到与其他智能体复杂交互的影响。这些交互进一步与地图依赖的交通规则交织在一起,使得理解场景中多个智能体的多样化行为变得极其困难。
最近,基于学习的方法已经在运动预测任务中展示了它们的有效性。受计算机视觉进展的启发,一些工作将场景光栅化为鸟瞰图,并应用卷积神经网络(CNNs)进行预测。尽管这些方法易于实现,并且可以使用现成的图像模型,但它们计算成本高且感受野有限。鉴于这些限制,最近的工作采用矢量化方法进行更紧凑的场景表示,从轨迹和地图元素中提取一组向量或点。然后,通过图神经网络、变换器或点云模型处理这些场景,以学习矢量化实体(如轨迹航点和车道线段)之间的关系。
然而,现有的矢量化方法在快速变化的交通条件下进行实时运动预测面临挑战。由于矢量化方法通常对参考框架的平移和旋转不具有鲁棒性,为了减轻这个问题,最近的研究将场景标准化,使其以目标智能体为中心并与目标智能体的航向对齐。这种补救措施在需要预测场景中的大量智能体时变得有问题,因为重新标准化场景和重新计算每个目标智能体的场景特征的成本很高。此外,现有工作在空间和时间维度上对所有实体之间的关系进行建模,以捕捉矢量化实体之间的细粒度交互,这随着实体数量的增加不可避免地导致计算量大。
由于实时进行准确预测对自动驾驶的安全至关重要,我们因此受到激励,通过开发一个新框架来推动现有技术水平,实现更快、更准确的多智能体运动预测。简而言之,我们的方法利用了多智能体运动预测问题的对称性和层次结构。我们在多个阶段对运动预测任务进行框架化,并基于变换器层次化地建模实体之间的关系。
在第一阶段,我们的框架避免了昂贵的全对全交互建模,并且仅在局部提取上下文特征。具体来说,我们将场景划分为一组局部区域,每个局部区域以一个建模的智能体为中心。对于每个以智能体为中心的局部区域,我们从包含与中心智能体相关的丰富信息的局部矢量化实体中提取上下文特征。在第二阶段,为了补偿局部接受场的限制并捕获场景中的长距离依赖性,我们通过增强变换器编码器与局部参考框架之间的几何关系,执行不同智能体中心局部区域之间的全局消息传递。最后,给定局部和全局表示,解码器在单次前向传播中为所有智能体产生未来轨迹。
为了进一步利用问题的对称性,我们采用了一种对全局坐标框架的平移不敏感的场景表示,在这种表示中,我们使用相对位置来描述所有矢量化实体。基于这种场景表示,我们引入了旋转不变的交叉注意力模块进行空间学习,这可以学习对场景旋转不变的局部和全局表示。我们的方法具有以下明显优势。首先,通过将问题分解为局部上下文提取和全局交互建模,我们的方法可以逐步聚合不同尺度的信息,并以高效率对场景中的大量实体进行建模。其次,我们的方法可以通过平移不变场景表示和旋转不变空间学习模块学习对输入的平移和旋转具有鲁棒性的表示。第三,与最先进的方法相比,我们的模型可以用更少的参数进行更快、更准确的预测。我们通过在大规模驾驶数据上的广泛实验验证了上述所有优势。我们的代码将公开提供。
2.相关工作
2.1交通场景表示
处理运动预测问题需要从交通场景的元素中学习丰富的表示,包括高清地图和智能体的过去轨迹。大量工作使用光栅化场景作为模型输入,并采用标准图像模型进行学习。具体来说,这些方法从高清地图中提取地图元素(例如,车道边界、人行横道、交通灯),并使用不同的颜色或掩码将场景渲染为鸟瞰图。智能体的过去轨迹要么作为额外的图像通道进行光栅化,要么通过RNNs等时间模型进行处理。光栅化方法与计算机视觉中的成熟技术兼容,但在学习上成本高且效率低。最近,矢量化方法因其高效的稀疏编码和捕捉复杂结构信息的能力而变得流行。与光栅化方法不同,这些方法将场景视为与语义和几何属性相关的实体集合,并学习实体之间的关系。例如,VectorNet通过图神经网络建模车道和轨迹多边线之间的交互,也被一些后续工作用作主干网络。LaneGCN从车道线段构建车道图,并利用多尺度图卷积网络学习图节点的表示。TPCN将点云模型扩展到学习由轨迹航点和车道点组成的时空点集。我们的场景表示也属于这一类,但所有矢量化实体都通过相对位置来表征,使我们的表示对全局坐标框架的平移具有不变性。
2.2运动预测
由于社交交互在交通场景中无处不在,并且显著影响交通智能体的未来运动,许多运动预测方法已经考虑了智能体行为之间的依赖性,并通过社交池化、图神经网络或注意力机制推理智能体-智能体交互。受到变换器模型在各个领域成功的启发,一些最近的工作在运动预测任务中采用变换器来建模空间关系、时间依赖性和智能体与地图元素之间的关系。与现有方法相比,我们的变换器架构通过层次化地学习局部和全局表示来学习。这种层次化策略帮助模型学习多尺度特征,并且比那些沿空间轴和时间轴执行全对全消息传递的方法更有效。此外,我们通过以智能体为中心的表示来建模多个智能体,该表示对场景的平移和旋转具有不变性。层次化架构和对称设计使我们的方法能够以比其他方法更少的参数和更低的计算成本实现最先进的预测性能。
3.方法
3.1总体框架
我们提出的框架概述如图 1 所示。我们首先将交通场景组织为一系列矢量化实体。基于这种场景表示,我们的框架在场景中分层聚合时空信息。在第一阶段,我们为每个智能体编码旋转不变的局部上下文特征。聚合自我运动、邻近智能体的运动和局部地图结构可以提供与被建模智能体相关的丰富信息。在第二阶段,为了补偿局部接受场的限制并捕获场景级动态,我们在全球范围内对不同智能体的局部上下文进行聚合,并通过几何关系增强变换器编码器的能力。最后,层次化学习到的表示被用来为所有智能体同时进行多模态轨迹预测。
3.2场景表示
交通场景由智能体和地图信息组成。为了以结构化的方式表示场景,我们从场景中提取矢量化实体,包括交通智能体的轨迹段和地图数据中的车道段。与涉及绝对位置的先前矢量化方法不同,我们的表示避免使用任何绝对位置,并使用相对位置来表征几何属性,从而使场景成为一个完全的矢量集。具体来说,智能体 i i i 的轨迹被表示为 { p t i − p t − 1 i } t = 1 T \{p_t^i - p_{t-1}^i\}_{t=1}^T {pti−pt−1i}t=1T,其中 p t i ∈ R 2 p_t^i \in \mathbb{R}^2 pti∈R2 是智能体 i i i 在时间步 t t t 的位置, T T T 是总的历史时间步数。对于车道段 ξ \xi ξ,几何属性由 p 1 ξ − p 0 ξ p_1^\xi - p_0^\xi p1ξ−p0ξ 给出,其中 p 0 ξ ∈ R 2 p_0^\xi \in \mathbb{R}^2 p0ξ∈R2 和 p 1 ξ ∈ R 2 p_1^\xi \in \mathbb{R}^2 p1ξ∈R2 分别是车道段 ξ \xi ξ 的起始和结束坐标。通过将点集转换为矢量集,这样的表示自然保证了平移不变性。然而,实体之间的相对位置信息也被丢弃了。为了保留空间关系,我们为智能体-智能体对和智能体-车道对引入了相对位置向量。例如,智能体 j j j 相对于智能体 i i i 在时间步 t t t 的位置向量是 p t j − p t i p_t^j - p_t^i ptj−pti,它完全描述了两个智能体之间的空间关系,并且也是平移不变的。不失一般性,我们的场景表示确保任何应用于它的可学习函数都将必然尊重平移不变性。
3.3层次化矢量变换器
为了在高度动态的环境中准确预测交通智能体的未来轨迹,模型需要有效地学习大量矢量化实体之间的时空关系。变换器已经证明在各种任务中捕捉实体之间的长距离依赖关系是有效的。然而,直接将变换器应用于时空实体会遭受 ( ( N T + L ) 2 ) ((N T + L)^2) ((NT+L)2) 的复杂度,其中 N N N、 T T T 和 L L L 分别是智能体、历史时间步和车道段的数量。为了有效地从大量实体中学习,我们的模型将空间和时间维度分解,并在每个时间步仅局部学习空间关系。具体来说,我们将空间划分为 N N N 个局部区域,每个局部区域以场景中的一个智能体为中心。在每个局部区域内,环境信息涉及中心智能体的邻近智能体的轨迹段和围绕中心智能体的局部车道段。对于每个局部区域,我们通过逐步建模每个时间步的智能体-智能体交互、每个智能体的时间依赖性和当前时间步的智能体-车道交互,将局部信息聚合到单个特征向量中。聚合后,特征向量包含了与中心智能体相关的丰富信息。另一方面,通过分解空间和时间维度,计算复杂度从 ( ( N T + L ) 2 ) ((N T + L)^2) ((NT+L)2) 降低到 ( N T 2 + T N 2 + N L ) (N T^2 + T N^2 + N L) (NT2+TN2+NL),并通过限制局部区域的半径进一步降低到 ( N T 2 + T N k + N ℓ ) (N T^2 + T N k + N \ell) (NT2+TNk+Nℓ),其中 k < N k < N k<N 和 ℓ < L \ell < L ℓ<L。虽然局部编码器可以局部学习丰富的表示,但信息量受到局部区域范围的限制。为了避免牺牲预测性能,我们进一步采用全局交互模块来补偿局部接受场的限制,并捕获场景级动态,其中我们执行局部区域之间的消息传递。全局交互模块可以显著增强模型的表达能力,代价是 ( N 2 ) (N^2) (N2) 的复杂度,与局部编码器相比相对较轻。多智能体运动预测问题表现出平移和旋转对称性。现有方法通过针对每个智能体重新标准化所有矢量化实体并多次进行单智能体预测来实现不变性。相比之下,我们的模型可以一次性为所有智能体进行预测,而不会牺牲不变性,这是通过使用平移不变的场景表示和旋转不变的空间学习模块实现的。我们将在以下部分更详细地说明模型组件。
3.3.1局部编码器
智能体-智能体交互。智能体-智能体交互模块的目标是为每个局部区域的每个时间步学习中心智能体与邻近智能体之间的关系。为了利用问题的对称性,我们引入了一个旋转不变的交叉注意力块来聚合空间信息。具体来说,我们统一地将中心智能体 i i i 的最新轨迹段 p T i − p T − 1 i p_T^i - p_{T-1}^i pTi−pT−1i 作为局部区域的参考向量,并根据参考向量的方向 θ i \theta_i θi 旋转所有局部向量。基于旋转后的向量及其相关的语义属性,我们使用多层感知机(MLP)来获得中心智能体 i i i 的嵌入 z t i ∈ R d h z_t^i \in \mathbb{R}^{d_h} zti∈Rdh 以及任何邻近智能体 j j j 的嵌入 z t i j ∈ R d h z_t^{ij} \in \mathbb{R}^{d_h} ztij∈Rdh 在任何时间步 t t t:
z t i = ϕ center ( R i ⊤ ( p t i − p t − 1 i ) , a i ) , z_t^i = \phi_{\text{center}} (R_i^\top (p_t^i - p_{t-1}^i), a_i), zti=ϕcenter(Ri⊤(pti−pt−1i),ai),
z t i j = ϕ nbr ( R i ⊤ ( p t j − p t − 1 j ) , R i ⊤ ( p t j − p t i ) , a j ) , z_t^{ij} = \phi_{\text{nbr}} (R_i^\top (p_t^j - p_{t-1}^j), R_i^\top (p_t^j - p_t^i), a_j), ztij=ϕnbr(Ri⊤(ptj−pt−1j),Ri⊤(ptj−pti),aj),
其中 ϕ center ( ⋅ ) \phi_{\text{center}}(·) ϕcenter(⋅) 和 ϕ nbr ( ⋅ ) \phi_{\text{nbr}}(·) ϕnbr(⋅) 是两个不同的 MLP 块, R i ∈ R 2 × 2 R_i \in \mathbb{R}^{2 \times 2} Ri∈R2×2 是由 θ i \theta_i θi 参数化的旋转矩阵, a i a_i ai 和 a j a_j aj 分别是智能体 i i i 和 j j j 的语义属性。由于所有几何属性在输入 MLP 之前都是相对于中心智能体进行归一化的,这些嵌入不受全局坐标框架旋转的影响。除了轨迹段之外, ϕ nbr ( ⋅ ) \phi_{\text{nbr}}(·) ϕnbr(⋅) 的输入还包括邻近智能体相对于中心智能体的位置向量,使邻近嵌入在空间上具有意识。然后将中心智能体的嵌入转换为查询向量,并将邻近智能体的嵌入用于计算键和值向量:
q t i = W Q space z t i , k t i j = W K space z t i j , v t i j = W V space z t i j , q_t^i = W_{Q_{\text{space}}} z_t^i, \quad k_t^{ij} = W_{K_{\text{space}}} z_t^{ij}, \quad v_t^{ij} = W_{V_{\text{space}}} z_t^{ij}, qti=WQspacezti,ktij=WKspaceztij,vtij=WVspaceztij,
其中 W Q space , W K space , W V space ∈ R d k × d h W_{Q_{\text{space}}}, W_{K_{\text{space}}}, W_{V_{\text{space}}} \in \mathbb{R}^{d_k \times d_h} WQspace,WKspace,WVspace∈Rdk×dh 是可学习的线性投影矩阵, d k d_k dk 是变换后的向量维度。得到的查询、键和值向量被用作缩放点积注意力块的输入:
α t i = softmax ( q t i ⊤ d k ⋅ { k t i j } j ∈ N i ) , \alpha_t^i = \text{softmax} \left( \frac{q_t^{i\top} \sqrt{d_k} \cdot \{ k_t^{ij} \}}{j \in N_i} \right), αti=softmax(j∈Niqti⊤dk⋅{ktij}),
m t i = ∑ j ∈ N i α t i j v t i j , m_t^i = \sum_{j \in N_i} \alpha_t^{ij} v_t^{ij}, mti=j∈Ni∑αtijvtij,
g t i = sigmoid ( W gate [ z t i , m t i ] ) , g_t^i = \text{sigmoid} (W_{\text{gate}} [z_t^i, m_t^i]), gti=sigmoid(Wgate[zti,mti]),
z ^ t i = g t i ⊙ W self z t i + ( 1 − g t i ) ⊙ m t i , \hat{z}_t^i = g_t^i \odot W_{\text{self}} z_t^i + (1 - g_t^i) \odot m_t^i, z^ti=gti⊙Wselfzti+(1−gti)⊙mti,
其中 N i N_i Ni 是智能体 i i i 的邻居集合, W gate W_{\text{gate}} Wgate 和 W self W_{\text{self}} Wself 是可学习的矩阵, ⊙ \odot ⊙ 表示逐元素乘积。与标准缩放点积注意力相比,我们的变体使用门控函数将环境特征 m t i m_t^i mti 与中心智能体的特征 z t i z_t^i zti 融合,使块能够更控制特征更新。像原始变换器架构一样,我们的注意力块也可以扩展到多头。多头注意力块的输出通过 MLP 块传递,以获得智能体 i i i 在时间步 t t t 的空间嵌入 s t i ∈ R d h s_t^i \in \mathbb{R}^{d_h} sti∈Rdh。此外,我们在每个块之前应用层归一化,在每个块之后应用残差连接。在实践中,这个模块可以使用高效的散射和聚集操作来并行化所有局部区域和所有时间步的学习。时间依赖性。为了进一步捕获每个局部区域的时间信息,我们在智能体-智能体交互模块之上使用时间变换器编码器。对于任何中心智能体 i i i,这个模块的输入序列由智能体-智能体交互模块在不同时间步返回的嵌入 s t i t = 1 T {s_t^i}_{t=1}^T stit=1T 组成。
时间依赖性。为了进一步捕获每个局部区域的时间信息,我们在智能体-智能体交互模块之上使用时间变换器编码器。对于任何中心智能体 i i i,这个模块的输入序列由智能体-智能体交互模块在不同时间步返回的嵌入 { s t i } t = 1 T \{s_t^i\}_{t=1}^T {sti}t=1T 组成。类似于 BERT,我们在这个输入序列的末尾添加了一个额外的可学习标记 s T + 1 ∈ R d h s_{T+1} \in \mathbb{R}^{d_h} sT+1∈Rdh。然后,我们为所有标记添加可学习的位置上嵌入,并把它们堆叠成一个矩阵 S i ∈ R ( T + 1 ) × d h S_i \in \mathbb{R}^{(T+1) \times d_h} Si∈R(T+1)×dh,该矩阵被送入时间注意力块:
Q i = S i W Q time , K i = S i W K time , V i = S i W V time , Q_i = S_i W_{Q_{\text{time}}}, \quad K_i = S_i W_{K_{\text{time}}}, \quad V_i = S_i W_{V_{\text{time}}}, Qi=SiWQtime,Ki=SiWKtime,Vi=SiWVtime,
其中 W Q time , W K time , W V time ∈ R d h × d k W_{Q_{\text{time}}}, W_{K_{\text{time}}}, W_{V_{\text{time}}} \in \mathbb{R}^{d_h \times d_k} WQtime,WKtime,WVtime∈Rdh×dk 是可学习的矩阵, M ∈ R ( T + 1 ) × ( T + 1 ) M \in \mathbb{R}^{(T+1) \times (T+1)} M∈R(T+1)×(T+1) 是一个时间掩码,它强制标记只关注前面的时间步。时间学习模块也由交替的多头注意力块和 MLP 块组成。我们输入更新后的额外标记,这些标记总结了局部区域的时空特征,到后续的智能体-车道交互模块中。
智能体-车道交互。局部地图结构可以指示中心智能体的未来意图。因此,我们把局部地图信息整合到嵌入中。我们首先旋转局部车道段和当前时间步 T T T 的智能体-车道相对位置向量。然后,通过 MLP 对旋转后的向量进行编码:
z i ξ = ϕ lane ( R i ⊤ ( p 1 ξ − p 0 ξ ) , R i ⊤ ( p 0 ξ − p T i ) , a ξ ) , z_{i\xi} = \phi_{\text{lane}} \left( R_i^\top \left( p_1^\xi - p_0^\xi \right), R_i^\top \left( p_0^\xi - p_T^i \right), a_\xi \right), ziξ=ϕlane(Ri⊤(p1ξ−p0ξ),Ri⊤(p0ξ−pTi),aξ),
其中 ϕ lane ( ⋅ ) \phi_{\text{lane}}(\cdot) ϕlane(⋅) 是车道段的 MLP 编码器, R i ∈ R 2 × 2 R_i \in \mathbb{R}^{2 \times 2} Ri∈R2×2 是智能体 i i i 局部区域的旋转矩阵, p 0 ξ ∈ R 2 p_0^\xi \in \mathbb{R}^2 p0ξ∈R2, p 1 ξ ∈ R 2 p_1^\xi \in \mathbb{R}^2 p1ξ∈R2,和 a ξ a_\xi aξ 分别是车道段 ξ \xi ξ 的起始位置、结束位置和语义属性。以中心智能体的空间-时间特征作为查询输入,以 MLP 编码的车道段特征作为键/值输入,智能体-车道注意力的计算方式与上述公式(3)至公式(7)相同。我们进一步应用另一个 MLP 块来获得中心智能体 i i i 的最终局部嵌入 h i ∈ R d h h_i \in \mathbb{R}^{d_h} hi∈Rdh。在顺序建模智能体-智能体交互、时间依赖性和智能体-车道交互之后,嵌入已经融合了与局部区域中心智能体相关的丰富信息。
3.3.2 全局交互模块
我们引入全局交互模块以捕获场景中的长距离依赖性。由于局部特征是在以智能体为中心的坐标框架中提取的,全局交互模块在执行局部区域间的消息传递时需要考虑局部坐标框架之间的几何关系。为此,我们扩展变换器编码器以识别局部参考框架之间的差异。例如,智能体 i i i 和智能体 j j j 的坐标框架之间的差异可以通过 p T j − p T i p_T^j - p_T^i pTj−pTi 和 Δ θ i j \Delta \theta_{ij} Δθij 参数化,其中 Δ θ i j \Delta \theta_{ij} Δθij 表示 θ j − θ i \theta_j - \theta_i θj−θi。在从智能体 j j j 向智能体 i i i 传递消息时,我们使用 MLP ϕ rel ( ⋅ ) \phi_{\text{rel}}(\cdot) ϕrel(⋅) 来获得成对嵌入 e i j e_{ij} eij:
e i j = ϕ rel ( R i ⊤ ( p T j − p T i ) , cos ( Δ θ i j ) , sin ( Δ θ i j ) ) . e_{ij} = \phi_{\text{rel}} \left( R_i^\top (p_T^j - p_T^i), \cos(\Delta \theta_{ij}), \sin(\Delta \theta_{ij}) \right). eij=ϕrel(Ri⊤(pTj−pTi),cos(Δθij),sin(Δθij)).
成对嵌入随后被纳入向量的转换中:
q ~ i = W Q global h i , k ~ i j = W K global [ h j , e i j ] , v ~ i j = W V global [ h j , e i j ] , \tilde{q}_i = W_{Q_{\text{global}}} h_i, \quad \tilde{k}_{ij} = W_{K_{\text{global}}} [h_j, e_{ij}], \quad \tilde{v}_{ij} = W_{V_{\text{global}}} [h_j, e_{ij}], q~i=WQglobalhi,k~ij=WKglobal[hj,eij],v~ij=WVglobal[hj,eij],
其中 h i h_i hi 和 h j h_j hj 是智能体 i i i 和 j j j 的局部嵌入, W Q global , W K global , W V global W_{Q_{\text{global}}}, W_{K_{\text{global}}}, W_{V_{\text{global}}} WQglobal,WKglobal,WVglobal 是可学习的矩阵。为了捕获全局成对交互,我们应用与局部编码器中相同的空间注意力块,后跟一个 MLP 块,输出任何智能体 i i i 的全局表示 h ~ i \tilde{h}_i h~i。
3.3.3 多模态未来解码器
交通智能体的未来运动本质上是多模态的。因此,我们参数化未来轨迹的分布为一个混合模型,其中每个混合分量是一个拉普拉斯分布。预测是为所有智能体一次性完成的。对于每个智能体 i i i 和每个分量 f f f,一个 MLP 接收局部和全局表示作为输入,并输出智能体在局部坐标框架中每个未来时间步的位置 μ t i , f ∈ R 2 \mu_t^{i,f} \in \mathbb{R}^2 μti,f∈R2 及其相关的不确定性 b t i , f ∈ R 2 b_t^{i,f} \in \mathbb{R}^2 bti,f∈R2。回归头的输出张量具有形状 [ F , N , H , 4 ] [F, N, H, 4] [F,N,H,4],其中 F F F 是混合分量的数量, N N N 是场景中的智能体数量, H H H 是预测的未来时间步数。我们还使用另一个 MLP 后跟一个 softmax 函数来为每个智能体产生混合模型的混合系数,其形状为 [ N , F ] [N, F] [N,F]。
3.4 训练
我们采用多样性损失来鼓励多个轨迹假设的多样性,这在训练期间仅优化模型预测的 F F F 个轨迹中的最优轨迹。在优化之前,我们首先计算模型为每个智能体和每个时间步预测的 F F F 个混合分量的位置与真实位置之间的误差。然后,我们对所有未来时间步的误差求和,以获得一个形状为 [ F , N ] [F, N] [F,N] 的误差矩阵,根据该矩阵我们为每个智能体选择误差最小的轨迹,即找到误差矩阵每列的最小值。最终的损失函数由回归损失 L reg L_{\text{reg}} Lreg 和分类损失 L cls L_{\text{cls}} Lcls 组成,权重相等:
L = L reg + L cls . L = L_{\text{reg}} + L_{\text{cls}}. L=Lreg+Lcls.
我们采用负对数似然作为回归损失:
L reg = − 1 N H ∑ i = 1 N ∑ t = T + 1 T + H log P ( R i ⊤ ( p t i − p T i ) ∣ μ ^ t i , b ^ t i ) , L_{\text{reg}} = -\frac{1}{NH} \sum_{i=1}^N \sum_{t=T+1}^{T+H} \log P \left( R_i^\top (p_t^i - p_T^i) \mid \hat{\mu}_t^i, \hat{b}_t^i \right), Lreg=−NH1i=1∑Nt=T+1∑T+HlogP(Ri⊤(pti−pTi)∣μ^ti,b^ti),
其中 P ( ⋅ ∣ ⋅ ) P(\cdot \mid \cdot) P(⋅∣⋅) 是拉普拉斯分布的概率密度函数, { μ ^ t i } t = T + 1 T + H \{\hat{\mu}_t^i\}_{t=T+1}^{T+H} {μ^ti}t=T+1T+H 和 { b ^ t i } t = T + 1 T + H \{\hat{b}_t^i\}_{t=T+1}^{T+H} {b^ti}t=T+1T+H 分别是智能体 i i i 的最佳预测轨迹的位置和不确定性。我们使用交叉熵损失作为分类损失,以优化混合系数。
4实验
4.1实验设置
4.1.1数据集
我们在大规模的Argoverse运动预测数据集上评估我们的预测框架,该数据集提供了智能体的轨迹和高清地图数据。数据集包含323557个真实世界的驾驶场景,并被划分为训练集、验证集和测试集,分别包含205942、39472和78143个样本。所有训练和验证场景都是5秒的序列,以10Hz的频率采样,而测试集中仅公开了前2秒的轨迹。给定最初的2秒观测,Argoverse运动预测挑战赛要求预测智能体未来3秒的运动。
4.1.2评估指标
我们使用标准的运动预测指标来评估我们的模型,包括最小平均位移误差(minADE)、最小最终位移误差(minFDE)和未命中率(MR)。这些指标允许模型为每个智能体预测多达6条轨迹。minADE指标衡量最佳预测轨迹与真实轨迹在所有未来时间步上的平均 ℓ 2 \ell_2 ℓ2距离,而minFDE衡量最终未来时间步的误差。最佳预测轨迹定义为具有最小端点误差的轨迹。MR指的是真实端点与最佳预测端点之间的距离超过2.0米的场景比率。
4.1.3实现细节
我们在RTX 2080 Ti GPU上使用AdamW优化器训练我们的模型64个周期,批量大小、初始学习率、权重衰减和dropout率分别设置为32、 3 × 1 0 − 4 3 \times 10^{-4} 3×10−4、 1 × 1 0 − 4 1 \times 10^{-4} 1×10−4和0.1。学习率使用余弦退火调度器进行衰减。我们的模型由1层智能体-智能体和智能体-车道交互模块、4层时间学习模块和3层全局交互模块组成。所有多头注意力块的头数为8。所有局部区域的半径为50米。我们遵循基线的惯例,将预测模式的数量 F F F设置为6。我们不使用集成方法和数据增强等技巧。我们基于具有64个隐藏单元的小模型和具有128个隐藏单元的大模型进行实验,分别称为HiVT-64和HiVT-128。
4.2消融研究
我们在Argoverse验证集上进行消融研究。除非特别说明,实验结果基于我们的64维模型HiVT-64。
每个模块的重要性。我们通过交替移除一个组件来展示每个模块对预测性能的贡献。如表1所示,每个组件都能在一定程度上提高性能。首先,没有智能体-智能体交互模块,模型无法捕获先前时间步的细粒度局部交互,导致性能下降。我们还注意到,增加此模块的层数可以进一步提高性能,但我们为了更高的效率保持使用一层。其次,时间学习模块对性能影响最大,因为在高度动态的交通场景中推断交通智能体的未来运动严重依赖历史信息。第三,车道信息在运动预测中起着至关重要的作用,因为交通智能体通常由于交通规则的约束而沿车道移动。此外,全局交互模块可以显著提高预测性能。这一结果验证了其捕获长距离依赖性的能力。
消融研究对注意力块的影响。我们评估了空间注意力块中的门控更新函数和时间注意力块中的时间掩码的影响。如表2所示,使用门控函数可以提高预测性能,可能是因为一些智能体与环境的交互并不多。表2的结果还表明,移除时间注意力块中的时间掩码会导致性能变差,这表明阻止标记关注后续时间步对模型是有益的。
4.3结果
与最先进技术的比较。我们在Argoverse测试集上与最先进模型进行比较。表4中的结果收集自2021年11月16日的Argoverse排行榜。HOME+GOHOME是表中唯一的光栅化方法,它使用的参数比大多数矢量化方法多,但在MR指标除外的性能上并不出色。HiVT-64使用比LaneGCN、mmTransformer和DenseTNT少82.1%、74.6%和40.0%的参数,在minADE和minFDE指标上显著优于它们。与Scene Transformer和MultiModalTransformer相比,HiVT-64使用比它们少95.7%和89.5%的参数,但仍然实现了相当或更好的性能。HiVT-128使用比Scene Transformer和MultiModalTransformer少83.5%和60.0%的参数,在minADE和minFDE方面优于表4中显示的所有方法。以上结果表明,我们的方法在预测性能和参数效率方面具有优越性。我们的方法在2021年11月16日的minADE方面排名第一,在Argoverse排行榜上保持竞争力。
推理速度。我们在Argoverse验证集上使用RTX 2080 Ti GPU和批量大小为32比较模型的推理速度。这样的批量大小接近每个场景的平均智能体数量。如表5所示,我们模型的所有变体都比基线快,当局部区域的半径不小于20米时,完整模型的预测精度超过了基线。尽管在表5中我们假设需要多次前向传递来进行多智能体预测,并显示了批量大小为32时的推理速度,但我们的方法实际上可以一次性为所有智能体进行准确预测,而不会牺牲不变性,这意味着快速和准确的多智能体预测。当批量大小为1且半径为50米时,HiVT-128的平均推理速度约为20毫秒,满足实时要求。
从表5中我们可以看到,添加全局交互模块引入的计算成本可以忽略不计,但显著提高了预测性能。这一结果验证了全局交互模块的有效性。我们还改变了局部区域的半径,以获得具有不同计算复杂度的模型。表5显示,减小半径可以加速整体模型的推理,使用过大的半径80米会导致推理变慢,但对性能没有帮助。我们的局部-全局架构允许实践者根据预测精度的要求和计算资源的限制选择合适的局部区域大小。定性结果。我们在Argoverse验证集上展示了HiVT-128的定性结果。为了清晰起见,我们每场景仅可视化两个智能体。如图3所示,我们的模型可以同时为复杂交通场景中的多个智能体进行准确、多模态和合理的预测。有趣的是,尽管数据集不包含有关交通灯状态的信息,左上角的示例表明我们的模型成功预测了车辆在交叉口的突然加速。
更多推荐
所有评论(0)