论文名称:Medical SAM adapter: Adapting segment anything model for medical image
segmentation 发表于《Medical Image Analysis》2025
作者:Junde Wu a , Ziyue Wang b , Mingxuan Hong a , Wei Ji c , Huazhu Fu d , Yanwu Xu e ,
Min Xu f,g , Yueming Jin a,b,*
https://doi.org/10.1016/j.media.2025.103547
目录
Existing Problem:
-
SAM在医学图像分割领域表现不好,因为缺乏具体医学知识,包括低图像对比度,组织边界模糊,微小病变区域等。
-
MedSAM解决此问题是fully fine-tune原始的SAM模型,这样计算成本和内存占用都很高。且fully fine-tune真的有必要嘛?因为先前的研究表明,预训练的视觉模型对医学图像具有很强的可移植性。
-
Adaption是一种高效参数微调技术,在NLP领域使用广泛,但面临两个挑战:(如果不懂的化还要看看Adaption)
(1) 医学图像模态通常是3D的,如CT,MRI,SAM可以被应用在3D图像的每个切片上来获取最终结果,但没有考虑到3D医学图像分割中固有的紧密的体积相关性——如何将2D SAM 适配到3D医学图像分割上;
(2) Adaption应用在CV领域的研究有限,特别是交互式模型(如SAM,prompt很重要),如何融合Adaption和prompt还没有被探索。
注:Adaption的主要思路就是将带有部分参数的adapter模块插入到原始模型中,在保持大型预训练模型不变的情况下,仅更新少量额外的adapter参数。
本文主要内容:
针对以上问题,本文提出Med-SA,是探索SAM在medical domain首批工作。
几个主要的点:
-
使用轻量且有效的adaptation技术,而不是fine-tune SAM,将特定领域的医学知识融入分割模型—— 参数高效微调PEFT技术Adaption 来微调预训练的SAM(小工作量,仅更新SAM参数(13M)的2%)
-
提出空间深度转置 Space-Depth Transpose(SD-Trans)来适配2D和3D场景
-
提出超提示适配器 Hyper-Prompting Adapter (HyP-Adpt) 以实现提示条件下的适配
SD-Trans
解决医学图像形态问题(3D图像)
作用:将输入嵌入的spatial维度转置到depth维度,使得相同的self-attention模块能够在给定不同输入时处理不同维度的信息(等会看看代码,是什么意思)
HyP-Adpt
解决基于Prompt的适配问题
作用:实现 prompt-conditioned Adaption。在该Adapter中,我们使用视觉prompt来生成一系列权重,这些权重可高效应用于Adapter模块,促进广泛且深入的prompt-Adapter交互。(看看什么是视觉Prompt)
相关概念
参数高效微调 PEFT
-
全称:Parameter-efficent fine-tuning
-
相比于fully fine-tuning全量微调,PEFT冻结大部分参数,只更新少量参数(通常少于5%)
-
相关研究表明(Bitfit),PEFT比全量微调效果好,因为PEFT能避免catastrophic forgetting灾难性遗忘,且在跨域场景中泛化性更强
-
在所有PEFT方法中,Adaption是微调大模型用于下游任务的有效工具。如(arXiv preprint arXiv:
2205.13535,arXiv preprint arXiv:2203.16329. )、
Method

Med-SA architecture
-
冻结预训练SAM的参数
-
加入Adapter模块,集成到指定位置
-
Adapter起到瓶颈模型的作用(看图1b),由下采样,ReLU激活,上采样组成
-
下采样使用简单的MLP将输入压缩到低维度
-
上采样使用另一个MLP将压缩后的嵌入拓展回原始维度
-
-
在SAM的encoder中,我们在每个标准的ViT模块中放置2个Adapter(看图1b)
-
第一个Adapter,在Multi-head Attention后面
-
Multi-head Attention在ViT中起关键作用,它通过增强模型捕捉输入的flattened image patches之前的关系,来捕捉全局和局部依赖。
-
-
第二个Adapter,在后面的残差路径中。
-
MLP在ViT中也很重要,可防止ViT生成秩为1的矩阵,避免输出退化。所以将Adapter并行集成到MLP模块中,Adapter的特定任务特征对固定分支的通用特征起到补充作用,丰富整体特征
-
我们在Adapter之后引入了缩放因子s,用于平衡这两类特征。(作者关于Adapter位置和缩放因子,做了消融实验,见论文4.7)
-
-
-
在SAM的decoder中,我们在每个标准的ViT模块中放置3个Adapter(看图1d)
-
第一个Adapter,用于整合prompt嵌入,引入Hyper-Prompting,超提示适配器(HyP-Adpt)
-
第二个Adapter,与encoder一致,用于适配经残差加强的嵌入,与MLP并行
-
第三个Adapter,在image嵌入到prompt 交叉注意力的残差连接之后(还需要再理解),并在此之后又连接了另一个残差连接和层归一化,以输出最终结果。
-
SD-Trans architecture
-
本文提出空间深度专转置 Space-Depth Transpose(SD-Trans)来适配3D场景,考虑每个slice之间的联系

-
如图1c所示,对于每一个block,我们将attention操作分成两个分支:space分支和depth分支(3D图像多一个深度维度)。对于一个带有深度D的3D样本
-
space分支:输入D×N×L到multi-head attention,N是embeddings的数量number,L是embeddings的长度long(N是样本数,L是特征数?)。D 对应操作的次数,使得交互能够在 N×L 上进行,从而捕捉空间相关性
-
depth分支:我们转置输入矩阵得到 N×D×L,输入multi-head attention,尽管使用了相同的注意力机制,但现在的交互是在 D×L 上进行的,这使得深度相关性的学习成为可能。最后,我们将深度分支的结果转置回其原始形状,并将它们添加到空间分支的输出中,从而整合了深度信息。
-
HyP-Adpt architecture
-
将adaptation应用在交互式视觉模型的探索还很少,且自然场景和医学领域的交互行为差距很大
-
本文提出Hyper-Prompting Adapter (HyP-Adpt),用于实现prompt-conditioned adaptation,如图1d所示

-
HyP-Adpt的详细结构如图2所示

-
HyP-Adpt的核心思想是利用prompt embedding为adapter生成用于知识条件化的权重。具体而言,我们利用投影和reshaping操作,从prompt embedding中生成一系列权重图。然后,这些权重图通过矩阵相乘被直接应用到adapter上
-
这种方法能够实现广泛且深入的特征级交互,同时与生成整个网络相比,还能显著减少所需的参数数量。
-
形式上,我们在adapter的缩减嵌入 edown上进行超提示操作,此时提示信息(point,box)被拼接并缩减为prompt embedding eprompt. 然后使用eprompt来生成权重图序列
-
以其中一个为例:
-
其中Re表示reshape,M表示MLP层将 eprompt∈RN×L 投影成 eprompt∈RN×Lin*Lout ,其中 * 是数值乘法,Lin是 edown的长度,Lout是输出的目标长度。(为了后续和edown的维度匹配)
-
随后我们将eprompt从一维嵌入reshape成了二维权重wprompt∈RN×Lin*Lout,并将其应用于edown,可以表示为:
-
其中⊗表示矩阵乘积,我们沿着长度维度对元素进行归一化,之后应用 ReLU 激活函数。
-
作者为Hyper-prompt设置了3层,每个权重都由单独的MLP层进行投影。HyP-Adpt 有助于根据提示信息调整参数,并且能更灵活地适应不同的模态和下游任务。
Prompt 生成策略
-
Med-SA仅采用sparse prompt:click prompt 和 bounding box
-
对于bounding box,采用和SAM相同的方式
-
对于click prompt,因为SAM论文仅给了有限的生成策略,本文设计了自己的生成方法。
-
使用positive click表示前景区域,negative click表示背景区域
-
结合随机和迭代点击采样策略来训练模型。先通过随机采样进行prompt初始化,然后利用迭代采样程序加入一些click。(看一下代码什么是迭代采样程序),这种采样策略模拟了与用户真实的交互,因为每次点击都被放置在网络利用之前的click集合生成的预测的错误区域中。(细节见代码)
-
-
-
本文对3D图像的每个切片进行提示,以确保分割的一致性
Experiment
-
只更新全部SAM模型参数的2%,比其他模型好(如MedSAM, nnUnet, TransUnet, UNetr, Swin-UNetr, segdiff等)
-
医学图像分割的主要挑战是3D图像处理和重叠器官的模糊边界
Datasets
-
针对不同模态的5个数据集 (包括:CT, MRI, 超声,解剖图,皮肤镜),与17个医学图像分割任务对比
-
BTCV:3D 腹部多器官数据集,CT 扫描图像,包含带标签的12种解剖结构,来自50名受试者,共有1463张轴向增强腹部临床 CT 图像,每次 CT 扫描包含 85 至 198 个切片,轴向尺寸为 512×512。
-
REFUGE2:2D 视网膜眼底图像数据集,用于两项分割任务(视盘和视杯),包含 1200 张带标签的分辨率为 2124×2056 的 RGB 图像
-
BraTS2021:3D 数据集,MRI 扫描图像。用于脑胶质母细胞瘤子区域分割,包含来自多个机构的 1280 次多参数 MRI 扫描。每次 MRI 扫描包含 155 个切片,轴向尺寸为 240×240。
-
TNMIX:2D 甲状腺结节分割基准数据集,超声图像,包含不同分辨率的图像,是一个混合数据集,其中有来自 TNSCUI(Ma 等人,2017)的 4554 个样本和来自 DDTI(Pedraza 等人,2015)的 637 个样本。
-
ISIC2019:2D 皮肤镜图像数据集,用于皮肤病变分析,包含 25331 张有标注的黑色素瘤分割标签的图像。这些数据来自不同的中心,具有不同的分辨率。
-
这些数据集总共涵盖了 17 个分割对象,Med-SA采用它们默认的训练集、验证集和测试集划分方式
-
评估标准
-
分别使用Dice score和Hausdorff Distance (HD) 作为关键指标,来评估逐像素的分割准确率和分割边界质量。
-
HD: 衡量两个点集之间的最大不匹配程度,描述了从一个集合中的点到另一个集合中最近点的最大距离。在医学图像分割中, 通常用于比较自动分割结果与真实标注(金标准)之间的边界差异。
-
HD计算示例:假设有两个点集 A 和 B,正向豪斯多夫距离 H(A,B)是 A 中每个点到 B 中最近点的距离的最大值,反向豪斯多夫距离 H(B,A)是 B 中每个点到 A 中最近点的距离的最大值,而豪斯多夫距离 H(A,B) = max(H(A,B), H(B,A))。在图像分割里,就是计算分割结果边界点集和真实边界点集之间的这种距离 。
-
-
还报告了几个分割任务的平均交并比(mIoU),以方便更好地进行比较。
实现细节
-
对2D图像,遵循SAM默认的训练设置
-
对3D图像,Med-SA使用更小的batch_size:16
-
对于 REFUGE2、TNMIX 和 ISIC 数据集,模型训练 40 个epochs。对于三维 BTCV 和 BraTS 数据集,训练延长至 60 个epochs。
-
所有实验均在 PyTorch 平台上实现,并在 4 块 NVIDIA A100 GPU 上进行训练和测试。
-
2D图像被调整成1024×1024,3D图像的原始体数据被调整为128×128×128,不进行裁剪/填充。
-
对于交互式模型的prompt设置,遵循以往研究进行4种不同的prompt设置
-
一个随机的正点,记为 “1 - point”
-
三个正点,记为 “3 - points”
-
与目标重叠 50% 的边界框,记为 “BBox 0.5”
-
与目标重叠 75% 的边界框,记为 “BBox 0.75”
-
-
为确保分割的一致性,我们为三维图像的每个切片单独提供提示。为了最小化随机提示带来的随机性,我们在所有数据集上重复进行了五次 Med - SA 实验,文中展示平均结果。
Github代码详解
-
下载并解压后,cd进入该目录,安装环境:conda env create -f environment.yml
-
conda activate sam_adapt 进入环境
-
下载预训练权重(SAM的sam_vit_b_01ec64.pth)放入./checkpoint/sam文件夹
-
下载给的案例数据集ISIC和csv文件放在data/isic目录下
-
即可运行(一张单卡16G的情况,batchsize=2, image_size调成了512来跑的isic)
Adaptation模块的代码
# 代码的models-common-adapter.py内
class Adapter(nn.Module):
def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True):
super().__init__()
self.skip_connect = skip_connect
D_hidden_features = int(D_features * mlp_ratio) # 隐藏层维度计算
self.act = act_layer()
self.D_fc1 = nn.Linear(D_features, D_hidden_features)
self.D_fc2 = nn.Linear(D_hidden_features, D_features)
def forward(self, x):
# x is (BT, HW+1, D)
xs = self.D_fc1(x) # 降维
xs = self.act(xs) # 激活
xs = self.D_fc2(xs) # 还原回原始维度
if self.skip_connect: # 如果有残差结构
x = x + xs
else:
x = xs
return x
加入Adapter后的ViT block
在image_encoder内,如果是sam_adpt模型,就使用AdapterBlock
# 对应图1b,1c,在代码的adapter_block.py内
class AdapterBlock(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
args,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
scale: float = 0.5,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.args = args
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
if(args.mid_dim != None):
adapter_dim = args.mid_dim
else:
adapter_dim = dim
self.MLP_Adapter = Adapter(adapter_dim, skip_connect=False) # MLP-adapter, no skip connection
self.Space_Adapter = Adapter(adapter_dim) # with skip connection
self.scale = scale
self.Depth_Adapter = Adapter(adapter_dim, skip_connect=False) # no skip connection
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x # 保留原始x,后面用作残差连接
# Window partition 窗口划分
if self.window_size > 0: # 如果窗口大小大于0 则将特征图划分为多个不重叠的窗口
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
## 3d branch 3D分支处理
if self.args.thd:
hh, ww = x.shape[1], x.shape[2]
if self.args.chunk: # 确定深度维度大小
depth = self.args.chunk
else:
depth = x.shape[0]
# 重新排列,将depth维度单独分离出来
xd = rearrange(x, '(b d) h w c -> (b h w) d c ', d=depth)
# xd = rearrange(xd, '(b d) n c -> (b n) d c', d=self.in_chans)
xd = self.norm1(xd) # 归一化
dh, _ = closest_numbers(depth) # 获取最接近的两个数,用于将深度维度重新排列成二维 dh*dw=depth
xd = rearrange(xd, 'bhw (dh dw) c -> bhw dh dw c', dh= dh)
xd = self.Depth_Adapter(self.attn(xd))
xd = rearrange(xd, '(b n) dh dw c ->(b dh dw) n c', n= hh * ww )
x = self.norm1(x) # 归一化
x = self.attn(x) # 注意力机制,给输入x乘上注意力系数 图中的 multi-head attention
x = self.Space_Adapter(x) # 图中第一个 Adapter
if self.args.thd: # 如果是3D图片
xd = rearrange(xd, 'b (hh ww) c -> b hh ww c', hh= hh )
x = x + xd
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x # resnet
# MLP块
xn = self.norm2(x) # 归一化
# 论文中并行的 MLP 和 Adapter self.scale是缩放因子
x = x + self.mlp(xn) + self.scale * self.MLP_Adapter(xn) # 第二个Adapter
return x
冻结参数操作
'''Train,在function.py文件内的def train_sam内'''
# 控制哪些参数参与训练、哪些参数被冻结(只训练Adapter模块,冻结其他权重)
if args.mod == 'sam_adpt':
for n, value in net.image_encoder.named_parameters():
if "Adapter" not in n: # 冻结非 Adapter 参数
value.requires_grad = False
else: # 解冻 Adapter 参数
value.requires_grad = True
decoder
decoder的代码中没有看到对应图1d的实现,请大佬们指教,以下是GitHub中的代码:
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_multimask_outputs = num_multimask_outputs
self.num_mask_tokens = max(4, num_multimask_outputs) # for backward compatibility on loading checkpoints
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
self.iou_prediction_head = MLP(
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
mask_slice = slice(0, self.num_multimask_outputs)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
记录
报错
安装环境:conda env create -f environment.yml 报错pip failed,因为找不到torch==1.12.1+cu113
去pytorch官网找了对应的版本安装
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
然后更新依赖包
conda env update -f environment.yml --prune # --prune 会移除环境中多余的包
还是报错找不到torch==1.12.1+cu113, 重新使用pip安装后,再更新依赖包成功
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
conda env update -f environment.yml --prune # --prune 会移除环境中多余的包

481

被折叠的 条评论
为什么被折叠?



