049. Replacing the Last Two C3k2 Blocks of the Backbone with Swin Transformer Blocks: A Hierarchical Design with Window Attention
1. It Started with a Strange mAP Problem
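Last month I was tuning an industrial defect detection model: YOLOv11s on a PCB dataset, 50 epochs, and mAP@0.5 was stuck at 78.3% no matter what I tried. Staring at the feature-map visualizations in TensorBoard, I noticed that the feature maps coming out of the last two backbone stages were almost completely smeared out spatially. Small defects (such as 0.5 mm scratches) were drowned in background noise.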
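My first instinct was to add an attention mechanism, but after working through SE and CBAM the gain was under 0.5 points. Later, rereading the Swin Transformer paper, it clicked: YOLOv11's C3k2 module is essentially a CSP-style block of residual bottlenecks whose intermediate outputs are concatenated. It models local detail well enough, but it lacks global interaction across windows. By the last two backbone stages the feature maps have already shrunk to 20x20 and 10x10, and at that point window attention is actually a better deal than global self-attention: it costs less compute and still preserves spatial structure.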
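So I replaced the last two C3k2 blocks with Swin Transformer Blocks. mAP jumped to 81.7%, a gain of 3.4 points. But don't celebrate too early: on the first run mAP actually dropped by 0.8%, and the cause turned out to be mishandled padding in the window partition. Below I walk through the pitfalls and the final solution.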
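To put rough numbers on the "less compute" argument, here is the per-layer complexity comparison from the Swin paper, global attention 4hwC^2 + 2(hw)^2C versus attention inside M x M windows 4hwC^2 + 2M^2hwC, evaluated for C = 512 and M = 4 (the values used later in this post). It is a back-of-the-envelope sketch: it counts only these two expressions and ignores the padding a 10x10 map needs.

```python
def msa_cost(h: int, w: int, c: int) -> int:
    """Global multi-head self-attention: 4hwC^2 + 2(hw)^2 C (Swin paper, Eq. 1)."""
    return 4 * h * w * c ** 2 + 2 * (h * w) ** 2 * c


def wmsa_cost(h: int, w: int, c: int, m: int) -> int:
    """Attention inside non-overlapping M x M windows: 4hwC^2 + 2M^2 hwC (Swin paper, Eq. 2)."""
    return 4 * h * w * c ** 2 + 2 * m ** 2 * h * w * c


C, M = 512, 4  # channel width and window size used for the last two stages in this post
for size in (20, 10):
    g = msa_cost(size, size, C)
    wd = wmsa_cost(size, size, C, M)
    print(f"{size}x{size}: global={g:,} window={wd:,} ratio={wd / g:.2f}")
# 20x20: global=583,270,400 window=425,984,000 ratio=0.73
# 10x10: global=115,097,600 window=106,496,000 ratio=0.93
```

At 20x20 the windowed form saves roughly a quarter of the cost; at 10x10 the 4hwC^2 projection term dominates both, so the gap is small. The windowed formula also assumes h and w are multiples of M, so padding 10x10 up to 12x12 adds cost on top of this estimate.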
2. The Core Design of the Swin Block: Don't Treat Window Attention as a Black Box
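The biggest difference between a Swin Transformer Block and C3k2: C3k2 is "channel mixing + residual", while the Swin block is "spatial partitioning + shifted windows". When swapping out the last two stages, pay special attention to two things: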
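- The feature map size must be divisible by the window size, or you have to pad. The last two backbone stages are 20x20 and 10x10. With a 4x4 window, 20/4 = 5 divides evenly, but 10/4 = 2.5 does not. I didn't handle this at first, and the window_partition helper (from the Swin reference code, not a PyTorch built-in) failed outright.
- The cyclic shift for shifted windows. The official Swin implementation uses torch.roll, but in my YOLO inference pipeline, which went through torch.jit.script, the roll step produced wrong results. If your export or inference toolchain has trouble with roll, the same shift can be expressed with slicing and concatenation.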
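A minimal sketch of the slice-and-concatenate version of the cyclic shift, assuming the (B, H, W, C) layout used by the block in section 3.3. roll_hw_by_slicing is a hypothetical helper name; it is functionally identical to torch.roll, so swapping it in only matters if your export or inference path mishandles roll.

```python
import torch


def roll_hw_by_slicing(x: torch.Tensor, shift_h: int, shift_w: int) -> torch.Tensor:
    """Cyclic shift of a (B, H, W, C) tensor along H and W using slice + cat.

    Intended to match torch.roll(x, shifts=(shift_h, shift_w), dims=(1, 2)).
    """
    H, W = x.shape[1], x.shape[2]
    sh, sw = shift_h % H, shift_w % W  # normalize negative or oversized shifts
    if sh:
        # the last sh rows wrap around to the top
        x = torch.cat([x[:, H - sh:], x[:, :H - sh]], dim=1)
    if sw:
        # the last sw columns wrap around to the left
        x = torch.cat([x[:, :, W - sw:], x[:, :, :W - sw]], dim=2)
    return x


x = torch.randn(2, 10, 10, 8)
for s in (-2, 2):  # forward shift and reverse shift, as in the Swin block
    assert torch.equal(roll_hw_by_slicing(x, s, s), torch.roll(x, shifts=(s, s), dims=(1, 2)))
```

The shifts are plain Python ints here, which keeps the slicing static; that is the property you want when the point is to give an exporter something simpler than roll.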
3. Implementation: Surgery to Replace C3k2 with a Swin Block
3.1 First, define the core components of the Swin Transformer Block
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # Pitfall I hit: keep the qkv bias, otherwise small models converge slowly
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # Relative position bias table: a plain nn.Parameter, not an nn.ParameterList
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        # Compute the relative position index for every token pair inside a window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift offsets to start at 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row-major index into the table
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        # x: (num_windows * B, N, C) with N = Wh * Ww; mask: (num_windows, N, N) or None
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # Index explicitly: TorchScript cannot unpack a tensor as a tuple.
        # qkv.unbind(0) is equivalent; both return views, no extra memory.
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        # Add the relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # Broadcast the per-window mask over the batch and the heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
```
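For intuition about relative_position_index, here is the same index computation pulled out into a standalone function (relative_position_index is just a name for this illustration) and evaluated on a tiny 2x2 window. Each entry is the row of relative_position_bias_table that a (query, key) pair reads from; for a 2x2 window the table has (2*2-1) x (2*2-1) = 9 rows, so the indices run from 0 to 8.

```python
import torch


def relative_position_index(wh: int, ww: int) -> torch.Tensor:
    """Same steps as in WindowAttention.__init__, for a wh x ww window."""
    coords = torch.stack(torch.meshgrid(torch.arange(wh), torch.arange(ww), indexing="ij"))  # 2, Wh, Ww
    flat = torch.flatten(coords, 1)  # 2, N
    rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # N, N, 2
    rel[:, :, 0] += wh - 1       # row offset now in [0, 2*Wh-2]
    rel[:, :, 1] += ww - 1       # column offset now in [0, 2*Ww-2]
    rel[:, :, 0] *= 2 * ww - 1   # flatten (row, col) offset into a single table index
    return rel.sum(-1)           # N, N


idx = relative_position_index(2, 2)
print(idx.tolist())  # [[4, 3, 1, 0], [5, 4, 2, 1], [7, 6, 4, 3], [8, 7, 5, 4]]
```

Pairs with the same spatial offset share one bias entry: the diagonal is always index 4 (offset (0, 0)), which is why 16 token pairs need only 9 learned biases per head.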
3.2 Window partition and merge: the most bug-prone part
```python
def window_partition(x, window_size):
    # x: (B, H, W, C) -> windows: (B * nW, window_size * window_size, C)
    B, H, W, C = x.shape
    # Don't view() directly: H and W may not be divisible by window_size.
    # Pad up to the next multiple first.
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad starts from the LAST dim: (C_left, C_right, W_left, W_right, H_top, H_bottom),
        # so this pads W on the right and H at the bottom
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    H_pad, W_pad = x.shape[1], x.shape[2]
    x = x.view(B, H_pad // window_size, window_size, W_pad // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
    return windows, (H_pad, W_pad)


def window_reverse(windows, window_size, H, W, pad_h, pad_w):
    # windows: (B * nW, window_size * window_size, C); H and W are the PADDED sizes
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    # Strip the padding added by window_partition
    if pad_h > 0 or pad_w > 0:
        x = x[:, :H - pad_h, :W - pad_w, :].contiguous()
    return x
```
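A quick check of the padding arithmetic, as a sketch: padded_window_grid is a hypothetical helper that applies the same pad formula as window_partition and reports the padded size and the number of windows per image, for the window sizes compared in the ablation of section 4. The second half verifies the F.pad argument order on a (B, H, W, C) tensor.

```python
import torch
import torch.nn.functional as F


def padded_window_grid(h: int, w: int, ws: int):
    """Return (H_pad, W_pad, windows per image) for an h x w map and ws x ws windows."""
    pad_h = (ws - h % ws) % ws  # same formula as window_partition
    pad_w = (ws - w % ws) % ws
    h_pad, w_pad = h + pad_h, w + pad_w
    return h_pad, w_pad, (h_pad // ws) * (w_pad // ws)


for size in (20, 10):
    for ws in (4, 7, 8):
        print(f"{size}x{size}, window {ws}: {padded_window_grid(size, size, ws)}")

# F.pad pads from the last dimension backwards:
# (C_left, C_right, W_left, W_right, H_top, H_bottom) for a (B, H, W, C) tensor
x = torch.ones(1, 10, 10, 3)
y = F.pad(x, (0, 0, 0, 2, 0, 2))
assert y.shape == (1, 12, 12, 3)
assert y[:, 10:].sum() == 0 and y[:, :, 10:].sum() == 0  # new rows/cols are zeros
assert torch.equal(y[:, :10, :10], x)                   # original content untouched
```

With window 4, the 20x20 map needs no padding (25 windows) and the 10x10 map becomes 12x12 (9 windows). With window 7 or 8 the 10x10 map is padded to 14x14 or 16x16, so roughly half or more of the tokens in its windows are zero padding.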
3.3 The Swin Transformer Block itself
```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        # If the feature map is no larger than one window, use a single window
        # covering the whole map and disable shifting
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        # Pitfall I hit: shift_size must be smaller than window_size
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )
        # Precompute the attention mask for the shifted windows
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt  # label each region of the shifted map
                    cnt += 1
            mask_windows, _ = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # Tokens from different regions must not attend to each other
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        # x: (B, C, H, W); (H, W) must equal input_resolution
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        shortcut = x
        x = self.norm1(x).view(B, H, W, C)
        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # Partition into (padded) windows
        windows, (H_pad, W_pad) = window_partition(shifted_x, self.window_size)
        # Window attention, masked when shifted
        attn_windows = self.attn(windows, mask=self.attn_mask)
        # Merge the windows and drop the padding
        shifted_x = window_reverse(attn_windows, self.window_size, H_pad, W_pad,
                                   pad_h=H_pad - H, pad_w=W_pad - W)
        # Reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        # Pre-norm residuals: x = x + Attn(LN(x)), then x = x + MLP(LN(x))
        x = shortcut + x.reshape(B, H * W, C)
        x = x + self.mlp(self.norm2(x))
        return x.transpose(1, 2).reshape(B, C, H, W)
```
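To see what the shifted-window mask actually looks like, the sketch below rebuilds the same region labelling on a toy 8x8 map with window 4 and shift 2 (sizes chosen so that no padding is needed). shifted_window_mask is a hypothetical standalone helper that mirrors the mask code in __init__ above; the example counts the masked (-100) entries in each of the four windows.

```python
import torch


def shifted_window_mask(h: int, w: int, ws: int, shift: int) -> torch.Tensor:
    """(nW, ws*ws, ws*ws) additive mask; assumes h and w are divisible by ws."""
    img_mask = torch.zeros(1, h, w, 1)
    slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
    cnt = 0
    for hs in slices:
        for wsl in slices:
            img_mask[:, hs, wsl, :] = cnt  # one label per region of the shifted map
            cnt += 1
    # Partition into ws x ws windows, ordered row-major over the window grid
    win = img_mask.view(1, h // ws, ws, w // ws, ws, 1).permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws)
    diff = win.unsqueeze(1) - win.unsqueeze(2)  # nonzero where two tokens come from different regions
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)


mask = shifted_window_mask(8, 8, ws=4, shift=2)
print(mask.shape)                            # torch.Size([4, 16, 16])
print((mask != 0).sum(dim=(1, 2)).tolist())  # [0, 128, 128, 192]
```

The top-left window holds tokens that were contiguous before the roll, so nothing is masked there; the other three windows mix tokens that the roll wrapped around from the opposite edge, and those pairs are blocked.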
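3.4 Replacing the last two stages of the YOLOv11 backbone
C3k2 is defined in ultralytics/nn/modules/block.py, while the YOLOv11 backbone layout itself is declared in the model YAML (ultralytics/cfg/models/11/yolo11.yaml) and assembled by parse_model in ultralytics/nn/tasks.py. For clarity, the snippets below treat the backbone as a hand-written module. Assume its last two stages look like this: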
```
Stage3: C3k2(256, 512, 3, True)  # output 20x20
Stage4: C3k2(512, 512, 3, True)  # output 10x10
```
Replace them with:
```python
# In the backbone's __init__
self.stage3 = SwinTransformerBlock(
    dim=512,
    input_resolution=(20, 20),
    num_heads=8,
    window_size=4,  # 20 / 4 = 5, divides evenly
    shift_size=2,   # half the window size
)
self.stage4 = SwinTransformerBlock(
    dim=512,
    input_resolution=(10, 10),
    num_heads=8,
    window_size=4,  # 10 is not divisible by 4, but window_partition pads 10x10 to 12x12 internally
    shift_size=2,
)
```
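Key change: the channel counts must line up in forward. The original stage-3 C3k2 takes 256 channels in and produces 512, whereas the Swin block keeps its input and output channels equal. So add a 1x1 convolution before stage 3 to lift 256 channels to 512, and call it right before self.stage3 in forward:
```python
self.stage3_conv = nn.Conv2d(256, 512, 1)
```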
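If you would rather keep the 256 to 512 projection and the Swin block together as one stage, a small wrapper does it. ProjectThenBlock is a hypothetical name; the sketch uses a plain nn.Conv2d as in the text (Ultralytics' own Conv module adds BatchNorm and SiLU, which you may want here as well), and a depthwise conv stands in for SwinTransformerBlock so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class ProjectThenBlock(nn.Module):
    """Hypothetical stage wrapper: 1x1 conv to match channels, then a same-channel block."""

    def __init__(self, c_in: int, c_out: int, block: nn.Module):
        super().__init__()
        # Only add the projection when the channel counts actually differ
        self.proj = nn.Conv2d(c_in, c_out, kernel_size=1) if c_in != c_out else nn.Identity()
        self.block = block  # must map (B, c_out, H, W) -> (B, c_out, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(self.proj(x))


# Stand-in for SwinTransformerBlock(dim=512, input_resolution=(20, 20), ...)
stand_in = nn.Conv2d(512, 512, kernel_size=3, padding=1, groups=512)
stage3 = ProjectThenBlock(256, 512, block=stand_in)
out = stage3(torch.randn(2, 256, 20, 20))
print(out.shape)  # torch.Size([2, 512, 20, 20])
```

Stage 4 (512 in, 512 out) gets an nn.Identity projection, so the same wrapper covers both stages without special-casing.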
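4. Ablation: the effect of window size and shift strategy
On a PCB defect dataset (5,000 training images, 1,000 test images, 8 defect classes), with YOLOv11s as the baseline and the last two stages replaced by Swin blocks, trained for 100 epochs at 640x640 input: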
| Configuration | mAP@0.5 | mAP@0.5:0.95 | Params | Inference time (ms) |
|---|---|---|---|---|
| Original C3k2 | 78.3% | 52.1% | 9.2M | 2.1 |
| Swin Block (window=4, shift=0) | 79.8% | 53.6% | 10.1M | 2.8 |
| Swin Block (window=4, shift=2) | 81.7% | 55.4% | 10.1M | 3.0 |
| Swin Block (window=7, shift=3) | 80.2% | 53.9% | 10.1M | 3.5 |
| Swin Block (window=8, shift=4) | 79.5% | 52.8% | 10.1M | 3.8 |
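Conclusions:
- Window 4 with shift 2 works best: a 4x4 window splits the 20x20 map into exactly 5x5 windows, and with the shift every spatial position gets connected across window boundaries.
- Larger windows (7 or 8) actually do worse. On the 10x10 map, a 7x7 or 8x8 window forces padding up to 14x14 or 16x16, so each window holds few real pixels, and self-attention plausibly degrades toward something like average pooling.
- Inference time goes up by 0.9 ms (2.1 to 3.0 ms) for a 3.4-point gain in mAP@0.5, which is a very good trade.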
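5. Lessons learned: three easily overlooked details
- LayerNorm placement: the Swin block applies LayerNorm before attention and before the MLP (pre-norm), whereas C3k2 uses BatchNorm. If training becomes unstable after the swap, check whether the BN layers' running_mean is frozen. I once had model.train() not set correctly, so the BN statistics never updated, and mAP dropped by a full 5 points.
- Gradient checkpointing: the Swin block uses roughly 30% more GPU memory than C3k2. If batch size 8 runs out of memory, wrap the Swin block's forward in torch.utils.checkpoint.checkpoint. torch.roll works fine inside a checkpointed function, because checkpoint simply re-runs the wrapped function during the backward pass, so there is no need to pull the shift out separately.
- Mixed-precision training: in fp16 the attention logits fed to softmax can overflow, so compute the softmax in fp32 inside WindowAttention.forward and cast back. Use attn = attn.float().softmax(dim=-1).to(v.dtype) rather than hard-coding .half(), so the same line also works in fp32 or bf16. Under torch.autocast, softmax is already run in fp32, so this mainly matters when the model itself has been cast to half precision.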
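A quick way to confirm that checkpointing and torch.roll get along is to compare gradients with and without torch.utils.checkpoint.checkpoint on a toy function that rolls, applies a differentiable op, and rolls back. shift_op_unshift is a made-up stand-in for the shifted-window path, not the block from section 3.3; use_reentrant=False selects the non-reentrant checkpoint variant that recent PyTorch releases recommend.

```python
import torch
from torch.utils.checkpoint import checkpoint


def shift_op_unshift(x: torch.Tensor, shift: int = 2) -> torch.Tensor:
    """Toy stand-in for the shifted-window path: roll, a nonlinearity, roll back."""
    y = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
    y = y.softmax(dim=-1) * y  # arbitrary differentiable op in place of attention
    return torch.roll(y, shifts=(shift, shift), dims=(1, 2))


x1 = torch.randn(2, 10, 10, 8, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

shift_op_unshift(x1).sum().backward()                                 # plain autograd
checkpoint(shift_op_unshift, x2, use_reentrant=False).sum().backward()  # recomputed in backward

print(torch.allclose(x1.grad, x2.grad))  # True
```

In a real block you would wrap the whole SwinTransformerBlock call, for example checkpoint(self.stage3, x, use_reentrant=False), and only during training, since checkpointing trades extra forward compute for memory.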
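One last honest note: replacing the last two backbone stages with Swin blocks is not a silver bullet. If the objects in your dataset are all fairly large (pedestrian detection, for example), window attention may add less than 1 point. But if you work on small-object detection (remote sensing imagery, industrial defects), the change is worth a try: it has given a consistent gain in all three of my projects.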




