049. Replacing the Last Two C3k2 Layers of the Backbone with Swin Transformer Blocks: A Hierarchical Design with Window Attention


I. It Started with a Puzzling mAP Plateau

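Last month I was tuning an industrial defect-detection model. YOLOv11s had trained for 50 epochs on a PCB dataset, and mAP@0.5 was stuck at 78.3% no matter what I tried. Looking at the feature-map visualizations in TensorBoard, I noticed that the outputs of the last two Backbone layers were almost completely smeared across the spatial dimensions: small defects (for example 0.5 mm scratches) were drowned out by background noise.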

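My first instinct was to add an attention mechanism, but after trying SE and CBAM the gain was under 0.5 points. Later, while reading the Swin Transformer paper, it dawned on me: YOLOv11's C3k2 module is essentially a CSP-style stack of convolutional residual bottlenecks. It models local detail well enough, but it lacks global interaction across windows. And by the last two Backbone layers the feature maps are already down to 20x20 and 10x10, where window attention is a better deal than global self-attention: it costs less compute and still preserves spatial structure.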
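To put "less compute" into numbers: the Swin paper gives the cost of global multi-head self-attention on an h x w map with C channels as 4hwC² + 2(hw)²C, and of window attention with M x M windows as 4hwC² + 2M²hwC (softmax ignored). The sketch below simply evaluates both formulas for the 20x20 map, assuming C=512 and M=4 as in the configuration of Section 3.4; the function names are illustrative.

```python
def msa_cost(h, w, C):
    """Global multi-head self-attention cost (Swin paper): 4hwC^2 + 2(hw)^2 C."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_cost(h, w, C, M):
    """Window attention cost with M x M windows (Swin paper): 4hwC^2 + 2 M^2 hw C."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


h = w = 20  # last-stage map size used in this article
C = 512     # channel width assumed from Section 3.4
M = 4       # window size

global_cost = msa_cost(h, w, C)
window_cost = wmsa_cost(h, w, C, M)
proj = 4 * h * w * C**2  # qkv + output projection term, identical in both

print(global_cost, window_cost)  # 583270400 425984000
print((global_cost - proj) / (window_cost - proj))  # 25.0, i.e. hw / M^2
```

The attention-specific term shrinks by hw/M² = 25x, but at 20x20 the shared projection term dominates, so the total cost drops by only about 27%. The gap widens quickly at higher resolutions, because the global term grows with (hw)².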

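So I replaced the last two C3k2 modules with Swin Transformer Blocks, and mAP@0.5 jumped to 81.7%, up 3.4 points. But don't celebrate too early: on the very first run mAP actually dropped by 0.8 points, which turned out to be caused by mishandled padding in the window partition. Below I walk through the pitfalls and the final solution.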

II. Core Design of the Swin Block: Don't Treat Window Attention as a Black Box

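The biggest difference between the Swin Transformer Block and C3k2: C3k2 is "channel mixing + residual", while the Swin block is "spatial partitioning + shifted windows". When replacing the last two layers, pay special attention to two points: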

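  1. The window size must divide the feature map size, or the map has to be padded. The last two Backbone layers are 20x20 and 10x10; with a 4x4 window, 20/4=5 divides evenly, but 10/4=2.5 does not. I didn't handle this at first, and window_partition (the helper from the Swin reference code, not a PyTorch built-in) crashed at its view() call.
  2. The cyclic shift of the shifted windows. The official Swin code implements it with torch.roll, but in my YOLO inference pipeline, after compiling with torch.jit.script, the roll step appeared to be optimized away and the results came out wrong. A workaround is to express the shift with plain slicing and concatenation.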
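One way to implement that workaround, sketched here as an assumption rather than the exact code from my project: rebuild the cyclic shift from slicing and torch.cat, which produces the same tensor as torch.roll but uses only the most basic operators. `cyclic_shift` is a hypothetical helper name.

```python
import torch


def cyclic_shift(x: torch.Tensor, shift_h: int, shift_w: int) -> torch.Tensor:
    """Equivalent to torch.roll(x, shifts=(shift_h, shift_w), dims=(1, 2)) for x of shape (B, H, W, C),
    built only from slicing and torch.cat."""
    H, W = x.shape[1], x.shape[2]
    sh = shift_h % H  # normalize negative shifts, e.g. -2 on H=10 becomes 8
    sw = shift_w % W
    if sh:
        # element at row i moves to row (i + sh) % H
        x = torch.cat([x[:, H - sh:], x[:, :H - sh]], dim=1)
    if sw:
        x = torch.cat([x[:, :, W - sw:], x[:, :, :W - sw]], dim=2)
    return x


x = torch.arange(2 * 10 * 10 * 3, dtype=torch.float32).reshape(2, 10, 10, 3)
shifted = cyclic_shift(x, -2, -2)     # forward shift before window attention
restored = cyclic_shift(shifted, 2, 2)  # reverse shift afterwards
print(torch.equal(shifted, torch.roll(x, shifts=(-2, -2), dims=(1, 2))))  # True
print(torch.equal(restored, x))  # True
```

In the block, the two torch.roll calls would be swapped for `cyclic_shift(x, -s, -s)` and `cyclic_shift(x, s, s)`.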

III. Code Implementation: Surgery from C3k2 to Swin Block

3.1 First, define the core components of the Swin Transformer Block

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Lesson learned: keep the qkv bias, otherwise small models converge slowly
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # Relative position bias table: a single nn.Parameter, not an nn.ParameterList
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)  # init used by the official Swin code

        # Relative position index for every token pair inside a window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift offsets so they start at 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row-major flattening of (dh, dw)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        # x: (num_windows*B, N, C) with N = Wh*Ww; mask: (num_windows, N, N) or None
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # Index explicitly instead of tuple-unpacking the tensor (the official Swin code does this for TorchScript)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # Relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
```
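The relative position index is the least obvious part of WindowAttention. The standalone sketch below repeats the same index construction for a tiny 2x2 window so the result can be read directly; `relative_position_index` is just an illustrative function name.

```python
import torch


def relative_position_index(wh: int, ww: int) -> torch.Tensor:
    """Same construction as in WindowAttention.__init__, for a wh x ww window."""
    coords = torch.stack(torch.meshgrid(torch.arange(wh), torch.arange(ww), indexing="ij"))  # 2, wh, ww
    coords_flatten = torch.flatten(coords, 1)  # 2, N with N = wh*ww
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N: (dh, dw) per token pair
    rel = rel.permute(1, 2, 0).contiguous()  # N, N, 2
    rel[:, :, 0] += wh - 1  # dh now in [0, 2*wh - 2]
    rel[:, :, 1] += ww - 1  # dw now in [0, 2*ww - 2]
    rel[:, :, 0] *= 2 * ww - 1  # flatten the (dh, dw) grid row-major
    return rel.sum(-1)  # N, N, values in [0, (2*wh - 1) * (2*ww - 1))


idx = relative_position_index(2, 2)
print(idx.tolist())  # [[4, 3, 1, 0], [5, 4, 2, 1], [7, 6, 4, 3], [8, 7, 5, 4]]
```

Every token pair with the same offset (dh, dw) maps to the same row of the bias table, the diagonal is always the zero-offset entry (4 here), and the (2·2−1)² = 9 possible offsets use rows 0 to 8. That is why the table has (2Wh−1)(2Ww−1) rows.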

3.2 Window partition and merge: where bugs are most likely

```python
def window_partition(x, window_size):
    # x: (B, H, W, C) -> windows: (B * num_windows, window_size * window_size, C)
    B, H, W, C = x.shape
    # Don't view directly: H and W may not be divisible by window_size
    # Correct approach: pad first
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad reads the tuple from the LAST dim backwards: (C_left, C_right, W_left, W_right, H_top, H_bottom)
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    H_pad, W_pad = x.shape[1], x.shape[2]
    x = x.reshape(B, H_pad // window_size, window_size, W_pad // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
    return windows, (H_pad, W_pad)

def window_reverse(windows, window_size, H, W, pad_h, pad_w):
    # windows: (B * num_windows, window_size * window_size, C); H and W are the *padded* sizes
    num_windows = (H // window_size) * (W // window_size)
    B = windows.shape[0] // num_windows  # integer division instead of float arithmetic
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    # Crop the padding back off
    if pad_h > 0 or pad_w > 0:
        x = x[:, :H - pad_h, :W - pad_w, :].contiguous()
    return x
```
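The pad order is the easiest trap in window_partition, because the tensor is channels-last. A quick check on a 10x10 map with window size 4 (the stage-4 setting used later) shows what the correct and the tempting-but-wrong tuples do:

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 10, 10, 8)  # (B, H, W, C), channels-last as in window_partition
window_size = 4
pad_h = (window_size - x.shape[1] % window_size) % window_size  # 2
pad_w = (window_size - x.shape[2] % window_size) % window_size  # 2

# Correct: the first pair pads C (not at all), the second pads W, the third pads H
padded = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
print(padded.shape)  # torch.Size([1, 12, 12, 8])

# Wrong: a 4-tuple pads the last two dims, i.e. C and W, and leaves H untouched
wrong = F.pad(x, (0, pad_w, 0, pad_h))
print(wrong.shape)  # torch.Size([1, 10, 12, 10])

# The padded 12x12 map splits into 3 x 3 = 9 windows of 4x4
print((padded.shape[1] // window_size) * (padded.shape[2] // window_size))  # 9
```

The wrong version does not even raise here; it silently changes the channel count, and the error only surfaces later, at a view or in a Linear layer.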

3.3 The Swin Transformer Block itself

```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size

        # Lesson learned: if the map is no larger than the window, use a single window and no shift
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must lie in [0, window_size)"

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim)
        )

        # Attention mask for shifted windows: label the regions created by the shift,
        # then block attention (-100 before softmax) between tokens from different regions
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows, _ = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        # x: (B, C, H, W); H and W must equal input_resolution because attn_mask is precomputed
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition into windows (pads internally when H or W is not a multiple of window_size)
        windows, (H_pad, W_pad) = window_partition(shifted_x, self.window_size)

        # Window attention
        attn_windows = self.attn(windows, mask=self.attn_mask)

        # Merge windows and crop the padding
        shifted_x = window_reverse(attn_windows, self.window_size, H_pad, W_pad,
                                   pad_h=H_pad - H, pad_w=W_pad - W)

        # Reverse shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Pre-norm residuals: one around attention, one around the MLP
        x = shortcut + x.reshape(B, H * W, C)
        x = x + self.mlp(self.norm2(x))
        # transpose() is non-contiguous, so reshape (view would raise) back to (B, C, H, W)
        return x.transpose(1, 2).reshape(B, C, H, W)
```
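The block above precomputes attn_mask from input_resolution, so it only matches inputs of exactly that size; a different imgsz at inference, or non-square (rectangular) batches during validation, produce a different H and W. A possible alternative, following the pad-then-shift ordering of the Swin detection codebase, is to build the mask at run time from the padded size. This is a sketch under that assumption (`build_shift_mask` is a hypothetical helper), not the code behind the numbers in Section IV.

```python
import torch


def build_shift_mask(H, W, window_size, shift_size, device=None):
    """Shifted-window attention mask for an H x W map, computed at run time.

    Assumes the map is padded to multiples of window_size *before* the cyclic shift,
    so the mask is built on the padded size (Hp, Wp).
    Returns (num_windows, window_size**2, window_size**2): 0 within a region, -100 across regions.
    """
    Hp = -(-H // window_size) * window_size  # ceil to a multiple of window_size
    Wp = -(-W // window_size) * window_size
    img_mask = torch.zeros((1, Hp, Wp, 1), device=device)
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt  # label the 9 regions created by the shift
            cnt += 1
    # Hp and Wp are exact multiples, so the label map splits into windows without extra padding
    m = img_mask.view(1, Hp // window_size, window_size, Wp // window_size, window_size, 1)
    m = m.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
    diff = m.unsqueeze(1) - m.unsqueeze(2)  # nonzero where two tokens come from different regions
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)


mask_10 = build_shift_mask(10, 10, window_size=4, shift_size=2)
mask_rect = build_shift_mask(12, 20, window_size=4, shift_size=2)  # e.g. a non-square map
print(mask_10.shape, mask_rect.shape)  # torch.Size([9, 16, 16]) torch.Size([15, 16, 16])
```

To use it, forward would pad first, then roll, then call `build_shift_mask(H, W, self.window_size, self.shift_size, device=x.device)` instead of reading self.attn_mask.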

3.4 Replacing the last two layers of the YOLOv11 Backbone

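The C3k2 class lives in ultralytics/nn/modules/block.py. Note that the official ultralytics repository has no backbone.py: the YOLOv11 backbone is declared in ultralytics/cfg/models/11/yolo11.yaml and assembled by parse_model in ultralytics/nn/tasks.py. The snippets below assume a backbone written as a Python module with named stages. Suppose the last two stages originally look like this: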

```
Stage3: C3k2(256, 512, 3, True)  # outputs 20x20
Stage4: C3k2(512, 512, 3, True)  # outputs 10x10
```

Replace them with:

```python
# In the backbone's __init__
self.stage3 = SwinTransformerBlock(
    dim=512,
    input_resolution=(20, 20),
    num_heads=8,
    window_size=4,  # 20/4=5, divides evenly
    shift_size=2    # half the window size
)
self.stage4 = SwinTransformerBlock(
    dim=512,
    input_resolution=(10, 10),
    num_heads=8,
    window_size=4,  # Note: 10 is not divisible by 4, but the block pads internally
    shift_size=2
)
```

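Key change: the channels have to line up in forward. The original Stage3 C3k2 takes 256 input channels and outputs 512, whereas the Swin block keeps the channel count unchanged. So add a 1x1 convolution before Stage3 to lift the input to 512 channels: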

```python
self.stage3_conv = nn.Conv2d(256, 512, 1)
```
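One way to keep the channel alignment and the Swin block together is a small wrapper, sketched below. `SwinStage` is a hypothetical name, and `block` stands in for SwinTransformerBlock(dim=512, ...); nn.Identity is used only so the shape check runs on its own. The BatchNorm + SiLU after the 1x1 conv mirrors Ultralytics' Conv module and is an assumption; the bare nn.Conv2d(256, 512, 1) above works too.

```python
import torch
import torch.nn as nn


class SwinStage(nn.Module):
    """Hypothetical wrapper: 1x1 conv to align channels, then a channel-preserving block."""

    def __init__(self, c_in: int, c_out: int, block: nn.Module):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.SiLU(),
        )
        self.block = block  # must map (B, c_out, H, W) -> (B, c_out, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)       # (B, c_in, H, W) -> (B, c_out, H, W)
        return self.block(x)   # channel count unchanged inside the Swin block


stage3 = SwinStage(256, 512, block=nn.Identity())  # swap Identity for the Swin block in practice
out = stage3(torch.randn(2, 256, 20, 20))
print(out.shape)  # torch.Size([2, 512, 20, 20])
```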

IV. Ablation: Effect of Window Size and Shift Strategy

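On the PCB defect dataset (5,000 training images, 1,000 test images, 8 defect classes), with YOLOv11s as the baseline and the last two layers replaced by Swin blocks, trained for 100 epochs at 640x640 input: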

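| Configuration | mAP@0.5 | mAP@0.5:0.95 | Params | Inference time (ms) |
|---|---|---|---|---|
| Original C3k2 | 78.3% | 52.1% | 9.2M | 2.1 |
| Swin Block (window=4, shift=0) | 79.8% | 53.6% | 10.1M | 2.8 |
| Swin Block (window=4, shift=2) | 81.7% | 55.4% | 10.1M | 3.0 |
| Swin Block (window=7, shift=3) | 80.2% | 53.9% | 10.1M | 3.5 |
| Swin Block (window=8, shift=4) | 79.5% | 52.8% | 10.1M | 3.8 |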

Conclusions

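  • Window size 4 with shift 2 works best: the 20x20 feature map splits into exactly 5x5 windows of 4x4, and with the shift, every spatial position gets to interact across window boundaries.
  • Larger windows (7 or 8) do worse. Neither size divides 20 or 10, so the maps need heavy zero padding (10x10 grows to 14x14 or 16x16, 20x20 to 21x21 or 24x24) and each window holds few real pixels; on such small feature maps, self-attention degenerates toward something like average pooling.
  • Inference time rises by about 1 ms (2.1 ms to 3.0 ms), while mAP@0.5 gains 3.4 points, which is a very good trade-off.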
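The padding argument is easy to quantify. The sketch below computes, for each window size in the table, how large the padded map becomes and what fraction of its tokens are zero padding; `padding_fraction` is an illustrative helper, and this only counts padding, it does not by itself explain every accuracy difference.

```python
import math


def padding_fraction(size: int, window_size: int):
    """Padded side length and fraction of zero-padding tokens for a size x size map."""
    padded = math.ceil(size / window_size) * window_size
    return padded, 1 - size**2 / padded**2


for ws in (4, 7, 8):
    for size in (20, 10):
        padded, frac = padding_fraction(size, ws)
        print(f"window={ws} map={size}x{size} -> {padded}x{padded}, padding {frac:.0%}")
```

For the 10x10 map, padding makes up about 31% of the tokens with window 4, 49% with window 7 and 61% with window 8; for the 20x20 map it is 0%, 9% and 31%.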

V. Personal Experience: Three Easily Overlooked Details

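  1. Where LayerNorm goes: the Swin block applies LayerNorm before attention and before the MLP (pre-norm), while C3k2 uses BatchNorm. If training becomes unstable after the swap, check whether the BN layers' running_mean is frozen. I once had BN statistics stop updating because model.train() wasn't set correctly, and mAP dropped by 5 points.

  2. Gradient checkpointing: the Swin block uses roughly 30% more GPU memory than C3k2. If batch size 8 runs out of memory, wrap the block's computation in torch.utils.checkpoint.checkpoint inside SwinBlock's forward. In my setup torch.roll caused problems inside the checkpointed function, so I moved the shift out of it. This is not a general limitation, though: the official Swin code checkpoints whole blocks, cyclic shift included.

  3. Mixed-precision training: in fp16 the attention logits can exceed the half-precision range, and the softmax overflows. One fix is to compute the softmax in fp32 inside WindowAttention.forward and cast back: attn = attn.float().softmax(dim=-1).to(v.dtype). Casting back to v.dtype instead of hard-coding .half() keeps the line working in fp32 and bf16 runs. Under torch.autocast, softmax already runs in fp32, so this mainly matters when the model is cast to half precision by hand.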
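A quick way to check whether checkpointing plays well with the cyclic shift in your PyTorch version: wrap a toy module that contains torch.roll and compare input gradients with a plain run. `ShiftedMLP` is a hypothetical stand-in, not the block from Section III, and `use_reentrant=False` assumes a reasonably recent PyTorch (the argument appeared in 1.11).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ShiftedMLP(nn.Module):
    """Toy stand-in for a Swin block: LayerNorm, cyclic shift, Linear, reverse shift, residual."""

    def __init__(self, dim: int, shift: int = 2):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, dim)
        self.shift = shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, H, W, C)
        y = torch.roll(self.norm(x), shifts=(-self.shift, -self.shift), dims=(1, 2))
        y = self.fc(y)
        return x + torch.roll(y, shifts=(self.shift, self.shift), dims=(1, 2))


torch.manual_seed(0)
block = ShiftedMLP(32)
x = torch.randn(2, 10, 10, 32, requires_grad=True)

# Plain forward/backward: all intermediate activations are stored
block(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed: activations inside `block` are recomputed during backward instead of stored
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))  # True
```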

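One last honest word: replacing the last two Backbone layers with Swin blocks is not a cure-all. If the targets in your dataset are mostly large (pedestrian detection, say), window attention may add less than 1 point. But for small-object detection (remote sensing, industrial defects), this change is worth a try: in my three projects it improved accuracy consistently.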
