Swin-Unet 论文通俗精读:把 U-Net 整套搬进 Transformer 后,作者到底“新”在哪?

DeOldify图像上色

DeOldify图像上色

图片生成
图片编辑
DeOldify

使用modelscope和gradio加载DeOldify图像上色的图像上色模型并前端推理。

论文:Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation(arXiv:2105.05537v1)
关键词:纯 Transformer、U 形结构、shifted window、patch merging / expanding、skip connection
声明:此文作为我的学习记录,欢迎大佬们批评指正

0. 先用一句话说清 Swin-Unet 在干嘛

如果你把 U-Net 当成“分割界的老大哥”,那 Swin-Unet 干的事就是:

保留 U-Net 的 U 形骨架(编码器-解码器 + 跳连),但把里面的卷积模块基本全换成 Swin Transformer block,让网络更擅长学全局和长距离关系。

论文里也说得很直白:CNN 因为卷积“天生局部”,不太擅长显式建模全局/长距离语义交互,所以作者想试试“纯 Transformer 的 U-Net”。

1. 为什么医学分割里大家这么爱 U-Net?它强在哪、弱在哪?

1.1 U-Net 为啥强

U-Net 的经典套路:

  • 编码器一路下采样:分辨率越来越小,但语义越来越强(能“看懂”这块是什么器官/组织)

  • 解码器一路上采样:把语义信息放大回原图,做像素级预测

  • 跳跃连接(skip connection):把编码器那些“高清细节”直接送给解码器,补回下采样丢掉的空间信息

论文在介绍 U-Net 这段几乎就是标准教科书描述:编码器提取大感受场深层特征,解码器上采样回输入分辨率,并用 skip 连接融合多尺度高分辨率特征减少空间信息损失。

1.2 但 U-Net(CNN)也有短板:全局关系学得不够“直接”

卷积是局部算子:一个 3×3 卷积一次只能看到周围 3×3 的邻居。你想让它“看远一点”,就得:

  • 堆很多层(慢慢扩大感受野)

  • 或者用空洞卷积、金字塔、注意力插件等

论文说得很直接:卷积的固有局部性导致 CNN 很难学习显式的全局和长距离语义交互

2. Swin-Unet 的整体结构:就是 U-Net 的“皮”,里面换成 Swin 的“骨头”

论文结构图(Figure 1)基本把所有关键点都画出来了:编码器、瓶颈、解码器、skip connection,全部由 Swin Transformer block 组成。

2.1 先把大框架记住(非常重要)

Swin-Unet = 4 件套:
1)Patch 切块 + Linear Embedding(把图像变成 token 序列)
2)Encoder:Swin block + Patch Merging(下采样、通道增大)
3)Bottleneck:最底部再用 Swin block 做深层语义
4)Decoder:Swin block + Patch Expanding(上采样、通道减小)
5)Skip:编码器同尺度特征拿来和解码器融合

你可以把它当成一句“人话”:

编码器负责“压缩理解”,解码器负责“放大还原”,skip 负责“补细节”。

3. 第一步:Patch Partition / Embedding(Transformer 为啥非要切 patch?)

论文给的设置很具体:

  • 输入图像切成 4×4 的 non-overlapping patches

  • 每个 patch 原始维度:4×4×3=48(RGB 三通道)

  • 再用一个 linear embedding 把 48 维投影到 C 维

通俗理解(真的很形象)

把一张图想成一张大海报,Transformer 不喜欢“像素点一个个看”,它喜欢:

  • 先把海报剪成一堆小方块(patch)

  • 每个小方块用一串数字描述(token)

  • 之后 Transformer 的注意力就在这些 token 之间“开会”

4. Swin Transformer Block:这玩意到底在干嘛?公式到底在写啥?

论文明确说:Swin block 和普通 MSA 不同,它基于 shifted windows。每个 block 包括:

  • LayerNorm(LN)

  • Multi-head self attention

  • Residual connection(残差)

  • 2-layer MLP + GELU 非线性

4.1 两个连续 block:一个 W-MSA,一个 SW-MSA(交替出现)

论文把“连续两个 block”的计算写成四条公式:

  • 第一个 block:W-MSA(窗口内注意力) + MLP

  • 第二个 block:SW-MSA(错位窗口注意力) + MLP

通俗理解
  • W-MSA:先“各小组内部开会”(窗口内)

  • SW-MSA:再把分组错位一下,让“隔壁组也能交流”

  • 交替几轮后,信息就能逐步传到更远处(局部 → 更全局)

我之前问ai“是不是就两层?”——这里的“两层”只是一个 stage 里最基本的两个连续 block 的写法,并不是整个网络就两层。论文这里是在讲 block 的公式,不是在讲全网深度。

5. Encoder:Swin block 学特征,Patch Merging 负责下采样 + 增通道

论文一句话很关键:

  • token 分辨率是 H/4 × W/4,先过两次连续 Swin block,分辨率和通道不变

  • 然后 Patch Merging:token 数量减少(2× downsampling)+ 通道变成 2 倍

  • 这一套在 encoder 里重复三次

5.1 Patch Merging 到底怎么“合并”的?

论文把 Patch Merging 说得很明确:

1)把输入 patches 分成 4 份(可以理解为每个 2×2 小块)
2)把这 4 份 concatenate 拼在一起
3)拼完后通道会变成 4 倍
4)再用一个 Linear 把通道“统一”到原来的 2 倍(而不是 4 倍)

超通俗类比

你可以把它想成“压缩图片”:

  • 原来 2×2 四个格子各有一堆信息

  • 现在把四个格子的信息打包到一个格子里(分辨率减半)

  • 打包后内容太多(通道 4 倍),所以用 Linear 再“整理压缩”一下变成 2 倍

6. Bottleneck(瓶颈)到底是啥?为什么只放两层?

“瓶颈到底是什么”,一句话就能记住:

瓶颈 = U 形结构最底部、分辨率最小、语义最浓缩的那一层(编码器和解码器交界处)。

论文这里给了一个很“工程味”的理由:

  • Transformer 堆得太深不容易收敛,所以 bottleneck 只用两个连续 Swin block

  • 并且这一段分辨率和通道保持不变

通俗理解

到瓶颈这层,其实已经“压到最小了”(比如 7×7 这种级别)。你再继续下采样,就像把一张地图缩到看不清路名——信息会开始变糊。

所以瓶颈层更多是在做“深度理解总结”,而不是继续变小。

7. Decoder:核心创新之一 —— Patch Expanding(不用卷积/插值也能上采样)

论文明确说:

  • decoder 和 encoder 对称

  • 用 patch expanding 做上采样

  • patch expanding 会把相邻维度 reshape 成更高分辨率(2× upsampling)

  • 同时把特征维度相应减少

7.1 重点来了:Patch Expanding 的“第一层”到底做了啥?(8C→16C→4C 那段)

论文把这段写得非常具体:

输入: (W/32 × H/32 × 8C)

步骤 1:Linear(全连接变换)把通道扩大一倍:
(W/32 × H/32 × 8C) → (W/32 × H/32 × 16C)

步骤 2:rearrange(重排)把“通道里的信息搬到空间上”:
(W/32 × H/32 × 16C) → (W/16 × H/16 × 4C)

你会疑惑:为啥先变大(8C→16C),最后反而变小(→4C)?

这其实很“像素搬家”。

你把每个位置看成一个格子:

  • 原来一个格子里有 8C 个“信息条”

  • 你要把一个格子拆成 2×2=4 个格子(分辨率×2)

  • 拆之前先把信息条整理得更“可分配”(Linear 扩到 16C)

  • 拆的时候把 16C 平均分到 4 个新格子里 → 每个新格子 4C

通俗一句话:

先把信息“整理扩容”,再把它拆成 4 份分给更密的网格。

而且论文做了消融:patch expanding 比双线性插值和反卷积都更好(Synapse 上 DSC 79.13 vs 77.63/76.15)。

8. Skip Connection:Transformer 也能吃“U-Net 这套”,并且确实有效

8.1 它到底怎么 fuse(融合)的?

论文说得很明确:

1)把 encoder 的浅层特征(shallow)和 decoder 的深层上采样特征(deep)拼接 concatenate
2)拼接后通道变多了
3)再接一个 Linear,让拼接后的通道维度恢复到“和上采样特征一样”(不让网络越拼越胖)

通俗理解
  • encoder 给你“高清细节”(边界、纹理)

  • decoder 给你“整体理解”(这块是肝,那块是胃)

  • 拼起来:既懂是什么,又知道边界在哪

  • Linear:把“拼太胖的通道”压回正常大小,方便后面继续处理

8.2 skip 用几条合适?论文做了消融告诉你答案

他们把 skip 数从 0、1、2、3 做了实验:

  • 0 条:DSC 72.46

  • 1 条:76.43

  • 2 条:78.93

  • 3 条:79.13

并且作者明确说:为了更鲁棒,最终用 3 条 skip(放在 1/4、1/8、1/16 分辨率处)。

这点其实挺“打脸”的:很多人会觉得 Transformer 不需要 skip,但这篇证明了 skip 对 Transformer 也很有用

9. 输出端:最后一次 4× 上采样 + linear projection 输出分割

论文在架构概述里提到:

  • 最后一个 patch expanding 做 4× 上采样把分辨率恢复到输入 (W×H)

  • 然后再用 linear projection 输出像素级类别预测

通俗理解就是:

前面 decoder 一路把特征“放大回去”,最后再用一层把每个像素的特征映射成“属于哪一类”。

10. 实验部分:它到底比谁强?强在哪?

10.1 数据集与指标(论文原话)

  • Synapse:30 个病例,3779 张 CT 切片

    指标:DSC(Dice)和 HD(Hausdorff Distance)

    2105.05537v1

  • ACDC:心脏 MRI,评估用平均 DSC

10.2 结果

Synapse 上 Swin-Unet:DSC 79.13,HD 21.55
ACDC 上 Swin-Unet:平均 DSC 90.00

论文还强调:虽然 DSC 提升不算爆炸,但 HD(边界误差)提升明显,说明边界预测更好。

(说人话:分割任务里“边界好不好”很关键,HD 低往往意味着轮廓更准。)

10.3 实现细节(对复现党很有用)

论文给了训练设置:

  • 输入 224×224

  • patch size = 4

  • PyTorch 1.7.0

  • V100 32GB

  • ImageNet 预训练权重初始化

  • batch size 24

11. 论文作者自己承认的局限

论文 Discussion 说得很实在:

  • Transformer 性能受预训练影响大

  • 他们直接用 ImageNet 上的 Swin 权重来初始化 encoder/decoder,可能不是最优

  • 未来想探索端到端的医学分割预训练

  • 以及 2D → 3D 的扩展(医学很多是 3D)

12. 所以:Swin-Unet 的“创新点/贡献”到底有哪些?(按论文口径整理)

论文自己把贡献写了三条

创新点 1:纯 Transformer 的 U 形结构(不是 CNN+Transformer 拼装)

  • 编码器、瓶颈、解码器都用 Swin Transformer block

  • 编码器做局部到全局注意力学习,解码器上采样回输入分辨率做像素级预测

通俗点:不是“卷积当主菜,Transformer 撒点葱花”,而是“整锅都换成 Transformer”。

创新点 2:Patch Expanding(不用卷积/插值的上采样)

  • 设计 patch expanding 做上采样 + 维度调整

  • 并且消融证明它确实比常用上采样方式更好

通俗点:把“上采样”也做成 Transformer 的玩法,不借助 CNN 的反卷积。

创新点 3:skip connection 对 Transformer 同样有效(并且条数越多越稳)

  • skip 融合多尺度特征减少空间信息损失

  • 消融:skip 从 0 到 3,性能逐步上升

通俗点:U-Net 的“祖传秘方”跳连,在 Transformer 身上也能继续发光。

13. 给小白的“背诵版速记”

你可以把 Swin-Unet 整体记成 5 句话:

1)把图切成 4×4 patch,每个 patch 变成 token(Linear 投到 C 维)
2)编码器:Swin block 学特征,Patch Merging 让分辨率减半、通道翻倍
3)瓶颈:最底部只堆两层 Swin block,尺寸和通道不变
4)解码器:Patch Expanding 用 Linear + rearrange 做 2× 上采样
5)跳连:encoder 的高清细节和 decoder 的语义特征拼起来,再 Linear 压回通道

结语

Swin-Unet 这篇论文最“值钱”的地方,不是它堆了多少层、跑了多少分,而是它把 U-Net 那套结构思维(下采样语义、上采样还原、跳连补细节)完整搬进了 Transformer,并且把上采样/下采样都做成 token 友好的形式,最后用消融实验把关键点一一验证了。

您可能感兴趣的与本文相关的镜像

DeOldify图像上色

DeOldify图像上色

图片生成
图片编辑
DeOldify

使用modelscope和gradio加载DeOldify图像上色的图像上色模型并前端推理。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值