CMX跨模态融合实战:用PyTorch复现RGB-X语义分割中的Transformer模块
最近在做一个多传感器融合的感知项目,团队里的小伙伴一直在讨论如何让RGB图像和深度、热成像这些“X模态”数据更好地协同工作。传统的多模态融合方法,要么简单地在输入层拼接,要么用两个独立的网络各自为政,效果总是不尽如人意。直到我们尝试了基于Transformer的CMX框架,那种特征间“互相理解、互相校正”的融合方式,才真正让模型性能上了一个台阶。这篇文章,我就从一个工程实践者的角度,带你手把手拆解CMX的核心模块——特征校正模块(CM-FRM)和特征融合模块(FFM),并用PyTorch把它们复现出来。无论你是想在自己的语义分割任务中集成多模态能力,还是单纯对Transformer在视觉融合中的应用感兴趣,相信这篇深度解析都能给你带来不少启发。
1. 理解CMX:为何双流Transformer是RGB-X融合的利器
在自动驾驶、机器人导航或者工业检测这些场景里,单一的RGB摄像头已经越来越难以满足复杂环境下的感知需求。深度相机能提供精确的几何距离,热成像能穿透烟雾、无视光照变化,事件相机对高速运动极其敏感。这些“X模态”数据与RGB图像天然互补,但如何让它们“1+1>2”,却是个老大难问题。
过去的方法大致分两种:一种是“早融合”,直接把不同模态的数据在输入层堆叠起来,喂给一个网络。这种方法简单粗暴,但问题在于,网络底层很难学会区分和处理来自不同传感器的、具有不同统计特性的噪声。另一种是“晚融合”,用两个独立的骨干网络分别提取特征,最后在高层进行融合。这种方式虽然尊重了各模态的特性,但特征间的交互太晚,往往错过了在中间层进行深度互补的机会。
CMX提出的双流Transformer架构,巧妙地走了中间路线。它保留了双流设计,让RGB和X模态拥有独立的特征提取路径,但在特征提取的过程中,就通过精心设计的模块进行密集的交互。这就像让两个专家在各自专精领域深耕的同时,不断交换笔记、互相提问,最终形成的报告自然比各自写完再拼凑要深刻得多。
其核心在于两个模块:
- 特征校正模块(CM-FRM):在空间和通道两个维度上,动态计算一个模态对另一个模态的“注意力权重”,用来自另一个模态的、经过校准的信息来增强当前模态的特征。这有效抑制了单一模态中的噪声和不确定性。
- 特征融合模块(FFM):在准备进行最终融合前,先让两个模态的特征进行一轮全局的、基于交叉注意力的“深度对话”,然后再通过高效的卷积操作将它们合二为一。
这种设计带来的最大好处是通用性。无论你的“X”是深度图、热力图还是激光雷达投影,CMX的交互机制都能工作,因为你不需要为每种模态设计特定的融合策略,Transformer的注意力机制自动学习如何建立模态间的关联。
2. 基石构建:动手实现特征校正模块(CM-FRM)
CM-FRM模块的直觉非常巧妙:它不认为两个模态的特征是平等的。相反,它让每个模态都“审视”一下对方,找出对方特征图中哪些位置、哪些通道的信息对自己是有益的,然后有选择性地吸收过来。这个过程在空间和通道两个维度上同时进行。
2.1 通道权重的计算:全局信息的提炼
通道注意力关注的是“什么样的特征通道更重要”。CMX这里采用了一个经典而有效的组合:同时利用平均池化和最大池化来聚合空间信息,因为前者能捕捉整体背景,后者对显著的独特特征更敏感。
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
import math
class ChannelWeights(nn.Module):
def __init__(self, dim, reduction=4):
super(ChannelWeights, self).__init__()
self.dim = dim
# 自适应池化,无论输入特征图多大,输出都是1x1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
# MLP用于计算权重。输入是拼接后的4*dim维向量,输出是2*dim维的权重
self.mlp = nn.Sequential(
nn.Linear(self.dim * 4, self.dim * 4 // reduction),
nn.ReLU(inplace=True), # 使用inplace节省内存
nn.Linear(self.dim * 4 // reduction, self.dim * 2),
nn.Sigmoid() # 输出压缩到[0,1]区间,作为权重
)
def forward(self, x1, x2):
B, C, H, W = x1.shape
# 将两个模态的特征在通道维度拼接
x = torch.cat((x1, x2), dim=1) # shape: [B, 2*C, H, W]
avg = self.avg_pool(x).view(B, self.dim * 2) # [B, 2*C]
max = self.max_pool(x).view(B, self.dim * 2) # [B, 2*C]
# 拼接平均和最大池化结果,获得更全面的全局描述
y = torch.cat((avg, max), dim=1) # [B, 4*C]
y = self.mlp(y).view(B, self.dim * 2, 1) # [B, 2*C, 1]
# 将权重重新整形,分离出两个模态各自的通道权重
# 输出形状: [2, B, C, 1, 1],其中第一维0对应x1的权重,1对应x2的权重
channel_weights = y.reshape(B, 2, self.dim, 1, 1).permute(1, 0, 2, 3, 4)



被折叠的 条评论
为什么被折叠?



