论文:Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation(arXiv:2105.05537v1)
关键词:纯 Transformer、U 形结构、shifted window、patch merging / expanding、skip connection
声明:此文作为我的学习记录,欢迎大佬们批评指正
0. 先用一句话说清 Swin-Unet 在干嘛
如果你把 U-Net 当成“分割界的老大哥”,那 Swin-Unet 干的事就是:
保留 U-Net 的 U 形骨架(编码器-解码器 + 跳连),但把里面的卷积模块基本全换成 Swin Transformer block,让网络更擅长学全局和长距离关系。
论文里也说得很直白:CNN 因为卷积“天生局部”,不太擅长显式建模全局/长距离语义交互,所以作者想试试“纯 Transformer 的 U-Net”。
1. 为什么医学分割里大家这么爱 U-Net?它强在哪、弱在哪?
1.1 U-Net 为啥强
U-Net 的经典套路:
-
编码器一路下采样:分辨率越来越小,但语义越来越强(能“看懂”这块是什么器官/组织)
-
解码器一路上采样:把语义信息放大回原图,做像素级预测
-
跳跃连接(skip connection):把编码器那些“高清细节”直接送给解码器,补回下采样丢掉的空间信息
论文在介绍 U-Net 这段几乎就是标准教科书描述:编码器提取大感受场深层特征,解码器上采样回输入分辨率,并用 skip 连接融合多尺度高分辨率特征减少空间信息损失。
1.2 但 U-Net(CNN)也有短板:全局关系学得不够“直接”
卷积是局部算子:一个 3×3 卷积一次只能看到周围 3×3 的邻居。你想让它“看远一点”,就得:
-
堆很多层(慢慢扩大感受野)
-
或者用空洞卷积、金字塔、注意力插件等
论文说得很直接:卷积的固有局部性导致 CNN 很难学习显式的全局和长距离语义交互。
2. Swin-Unet 的整体结构:就是 U-Net 的“皮”,里面换成 Swin 的“骨头”
论文结构图(Figure 1)基本把所有关键点都画出来了:编码器、瓶颈、解码器、skip connection,全部由 Swin Transformer block 组成。
2.1 先把大框架记住(非常重要)
Swin-Unet = 4 件套:
1)Patch 切块 + Linear Embedding(把图像变成 token 序列)
2)Encoder:Swin block + Patch Merging(下采样、通道增大)
3)Bottleneck:最底部再用 Swin block 做深层语义
4)Decoder:Swin block + Patch Expanding(上采样、通道减小)
5)Skip:编码器同尺度特征拿来和解码器融合
你可以把它当成一句“人话”:
编码器负责“压缩理解”,解码器负责“放大还原”,skip 负责“补细节”。
3. 第一步:Patch Partition / Embedding(Transformer 为啥非要切 patch?)
论文给的设置很具体:
-
输入图像切成 4×4 的 non-overlapping patches
-
每个 patch 原始维度:4×4×3=48(RGB 三通道)
-
再用一个 linear embedding 把 48 维投影到 C 维
通俗理解(真的很形象)
把一张图想成一张大海报,Transformer 不喜欢“像素点一个个看”,它喜欢:
-
先把海报剪成一堆小方块(patch)
-
每个小方块用一串数字描述(token)
-
之后 Transformer 的注意力就在这些 token 之间“开会”
4. Swin Transformer Block:这玩意到底在干嘛?公式到底在写啥?
论文明确说:Swin block 和普通 MSA 不同,它基于 shifted windows。每个 block 包括:
-
LayerNorm(LN)
-
Multi-head self attention
-
Residual connection(残差)
-
2-layer MLP + GELU 非线性
4.1 两个连续 block:一个 W-MSA,一个 SW-MSA(交替出现)
论文把“连续两个 block”的计算写成四条公式:
-
第一个 block:W-MSA(窗口内注意力) + MLP
-
第二个 block:SW-MSA(错位窗口注意力) + MLP
通俗理解
-
W-MSA:先“各小组内部开会”(窗口内)
-
SW-MSA:再把分组错位一下,让“隔壁组也能交流”
-
交替几轮后,信息就能逐步传到更远处(局部 → 更全局)
我之前问ai“是不是就两层?”——这里的“两层”只是一个 stage 里最基本的两个连续 block 的写法,并不是整个网络就两层。论文这里是在讲 block 的公式,不是在讲全网深度。
5. Encoder:Swin block 学特征,Patch Merging 负责下采样 + 增通道
论文一句话很关键:
-
token 分辨率是 H/4 × W/4,先过两次连续 Swin block,分辨率和通道不变
-
然后 Patch Merging:token 数量减少(2× downsampling)+ 通道变成 2 倍
-
这一套在 encoder 里重复三次
5.1 Patch Merging 到底怎么“合并”的?
论文把 Patch Merging 说得很明确:
1)把输入 patches 分成 4 份(可以理解为每个 2×2 小块)
2)把这 4 份 concatenate 拼在一起
3)拼完后通道会变成 4 倍
4)再用一个 Linear 把通道“统一”到原来的 2 倍(而不是 4 倍)
超通俗类比
你可以把它想成“压缩图片”:
-
原来 2×2 四个格子各有一堆信息
-
现在把四个格子的信息打包到一个格子里(分辨率减半)
-
打包后内容太多(通道 4 倍),所以用 Linear 再“整理压缩”一下变成 2 倍
6. Bottleneck(瓶颈)到底是啥?为什么只放两层?
“瓶颈到底是什么”,一句话就能记住:
瓶颈 = U 形结构最底部、分辨率最小、语义最浓缩的那一层(编码器和解码器交界处)。
论文这里给了一个很“工程味”的理由:
-
Transformer 堆得太深不容易收敛,所以 bottleneck 只用两个连续 Swin block
-
并且这一段分辨率和通道保持不变
通俗理解
到瓶颈这层,其实已经“压到最小了”(比如 7×7 这种级别)。你再继续下采样,就像把一张地图缩到看不清路名——信息会开始变糊。
所以瓶颈层更多是在做“深度理解总结”,而不是继续变小。
7. Decoder:核心创新之一 —— Patch Expanding(不用卷积/插值也能上采样)
论文明确说:
-
decoder 和 encoder 对称
-
用 patch expanding 做上采样
-
patch expanding 会把相邻维度 reshape 成更高分辨率(2× upsampling)
-
同时把特征维度相应减少
7.1 重点来了:Patch Expanding 的“第一层”到底做了啥?(8C→16C→4C 那段)
论文把这段写得非常具体:
输入: (W/32 × H/32 × 8C)
步骤 1:Linear(全连接变换)把通道扩大一倍:
(W/32 × H/32 × 8C) → (W/32 × H/32 × 16C)
步骤 2:rearrange(重排)把“通道里的信息搬到空间上”:
(W/32 × H/32 × 16C) → (W/16 × H/16 × 4C)
你会疑惑:为啥先变大(8C→16C),最后反而变小(→4C)?
这其实很“像素搬家”。
你把每个位置看成一个格子:
-
原来一个格子里有 8C 个“信息条”
-
你要把一个格子拆成 2×2=4 个格子(分辨率×2)
-
拆之前先把信息条整理得更“可分配”(Linear 扩到 16C)
-
拆的时候把 16C 平均分到 4 个新格子里 → 每个新格子 4C
通俗一句话:
先把信息“整理扩容”,再把它拆成 4 份分给更密的网格。
而且论文做了消融:patch expanding 比双线性插值和反卷积都更好(Synapse 上 DSC 79.13 vs 77.63/76.15)。
8. Skip Connection:Transformer 也能吃“U-Net 这套”,并且确实有效
8.1 它到底怎么 fuse(融合)的?
论文说得很明确:
1)把 encoder 的浅层特征(shallow)和 decoder 的深层上采样特征(deep)拼接 concatenate
2)拼接后通道变多了
3)再接一个 Linear,让拼接后的通道维度恢复到“和上采样特征一样”(不让网络越拼越胖)
通俗理解
-
encoder 给你“高清细节”(边界、纹理)
-
decoder 给你“整体理解”(这块是肝,那块是胃)
-
拼起来:既懂是什么,又知道边界在哪
-
Linear:把“拼太胖的通道”压回正常大小,方便后面继续处理
8.2 skip 用几条合适?论文做了消融告诉你答案
他们把 skip 数从 0、1、2、3 做了实验:
-
0 条:DSC 72.46
-
1 条:76.43
-
2 条:78.93
-
3 条:79.13
并且作者明确说:为了更鲁棒,最终用 3 条 skip(放在 1/4、1/8、1/16 分辨率处)。
这点其实挺“打脸”的:很多人会觉得 Transformer 不需要 skip,但这篇证明了 skip 对 Transformer 也很有用。
9. 输出端:最后一次 4× 上采样 + linear projection 输出分割
论文在架构概述里提到:
-
最后一个 patch expanding 做 4× 上采样把分辨率恢复到输入 (W×H)
-
然后再用 linear projection 输出像素级类别预测
通俗理解就是:
前面 decoder 一路把特征“放大回去”,最后再用一层把每个像素的特征映射成“属于哪一类”。
10. 实验部分:它到底比谁强?强在哪?
10.1 数据集与指标(论文原话)
-
Synapse:30 个病例,3779 张 CT 切片
指标:DSC(Dice)和 HD(Hausdorff Distance)2105.05537v1
-
ACDC:心脏 MRI,评估用平均 DSC
10.2 结果
Synapse 上 Swin-Unet:DSC 79.13,HD 21.55
ACDC 上 Swin-Unet:平均 DSC 90.00
论文还强调:虽然 DSC 提升不算爆炸,但 HD(边界误差)提升明显,说明边界预测更好。
(说人话:分割任务里“边界好不好”很关键,HD 低往往意味着轮廓更准。)
10.3 实现细节(对复现党很有用)
论文给了训练设置:
-
输入 224×224
-
patch size = 4
-
PyTorch 1.7.0
-
V100 32GB
-
ImageNet 预训练权重初始化
-
batch size 24
11. 论文作者自己承认的局限
论文 Discussion 说得很实在:
-
Transformer 性能受预训练影响大
-
他们直接用 ImageNet 上的 Swin 权重来初始化 encoder/decoder,可能不是最优
-
未来想探索端到端的医学分割预训练
-
以及 2D → 3D 的扩展(医学很多是 3D)
12. 所以:Swin-Unet 的“创新点/贡献”到底有哪些?(按论文口径整理)
论文自己把贡献写了三条
创新点 1:纯 Transformer 的 U 形结构(不是 CNN+Transformer 拼装)
-
编码器、瓶颈、解码器都用 Swin Transformer block
-
编码器做局部到全局注意力学习,解码器上采样回输入分辨率做像素级预测
通俗点:不是“卷积当主菜,Transformer 撒点葱花”,而是“整锅都换成 Transformer”。
创新点 2:Patch Expanding(不用卷积/插值的上采样)
-
设计 patch expanding 做上采样 + 维度调整
-
并且消融证明它确实比常用上采样方式更好
通俗点:把“上采样”也做成 Transformer 的玩法,不借助 CNN 的反卷积。
创新点 3:skip connection 对 Transformer 同样有效(并且条数越多越稳)
-
skip 融合多尺度特征减少空间信息损失
-
消融:skip 从 0 到 3,性能逐步上升
通俗点:U-Net 的“祖传秘方”跳连,在 Transformer 身上也能继续发光。
13. 给小白的“背诵版速记”
你可以把 Swin-Unet 整体记成 5 句话:
1)把图切成 4×4 patch,每个 patch 变成 token(Linear 投到 C 维)
2)编码器:Swin block 学特征,Patch Merging 让分辨率减半、通道翻倍
3)瓶颈:最底部只堆两层 Swin block,尺寸和通道不变
4)解码器:Patch Expanding 用 Linear + rearrange 做 2× 上采样
5)跳连:encoder 的高清细节和 decoder 的语义特征拼起来,再 Linear 压回通道
结语
Swin-Unet 这篇论文最“值钱”的地方,不是它堆了多少层、跑了多少分,而是它把 U-Net 那套结构思维(下采样语义、上采样还原、跳连补细节)完整搬进了 Transformer,并且把上采样/下采样都做成 token 友好的形式,最后用消融实验把关键点一一验证了。
9906

被折叠的 条评论
为什么被折叠?



