1. 医学图像分割的“老大难”问题:为什么常规损失函数会失灵?
如果你刚开始接触医学图像分割,比如想从CT或者MRI图像里把肿瘤、器官或者细胞给圈出来,你可能会兴冲冲地抄起一个经典的二分类交叉熵损失(Binary Cross Entropy, BCE)就开始训练。结果跑了几轮,一看指标,准确率(Accuracy)高得吓人,心里正美呢,结果把预测图可视化出来一看,傻眼了——模型啥也没分割出来,预测的全是背景!
这可不是模型偷懒,而是医学图像分割任务本身自带的一个“坑”:极端的类别不平衡。想象一下,一张512x512的肺部CT切片,真正的病灶区域可能只有几十个像素点,而背景(正常的肺组织、空气等)占据了图像的绝大部分。对于BCE这种“老实人”损失函数来说,它追求的是所有像素点的平均分类正确率。模型很快就会发现一个“作弊”捷径:我只要把所有像素都预测成背景,就能获得一个非常高的准确率,因为背景像素实在太多了。至于那一点点前景目标?忽略掉对整体损失的影响微乎其微。这就好比一场考试,99道题都是1+1=2,只有1道是微积分,你全答对1+1,总分也能有99分,但这并不能说明你学会了微积分。
所以,在医学图像分割这个领域,我们不能只看“全局平均分”,必须把注意力强行拉到我们关心的“重点考题”——也就是前景目标上。这就是Dice Loss、IoU Loss这类基于区域重叠度的损失函数登场的背景。它们不关心你猜对了多少背景,只关心你预测的前景区域和真实的前景区域,到底重合了多少。这个思路的转变,是解决医学图像分割核心挑战的关键第一步。我自己在早期项目里就踩过这个坑,用BCE训出来的模型看似指标漂亮,实际根本不能用,白白浪费了好几天算力。
2. Dice Loss:专治“小目标”和“不平衡”的利器
2.1 用“披萨饼”理解Dice系数
要理解Dice Loss,得先搞懂它的核心——Dice系数。咱们不用公式,先用一个生活例子。假设你和朋友合买了一个披萨(这是真实的前景区域),你用番茄酱在披萨上圈出了你认为属于你的部分(这是你的预测)。Dice系数关心的就是:你们俩公认的、重叠的那部分披萨面积有多大。
计算公式是:Dice = (2 * 重叠面积) / (你的披萨面积 + 朋友的披萨面积)。为什么要乘以2?这是为了把系数的最大值归一化到1。当你们俩圈定的区域完全一致时,重叠面积等于每个人的面积,分子是2A,分母也是2A,Dice就等于1,表示完美重合。如果你们圈的区域完全没有交集,重叠面积为0,Dice就是0。
在图像里,“面积”就是像素个数。Dice系数衡量的是预测掩膜(predicted mask)和真实掩膜(ground truth mask)之间的重叠程度。它天生就对类别不平衡不敏感,因为公式里根本没有背景像素什么事儿,分母只包含了预测的前景和真实的前景。模型想通过“全预测背景”来作弊?对不起,在Dice这里行不通,因为你的预测前景面积是0,Dice直接就是0,损失巨大。
2.2 手把手实现一个稳健的Dice Loss
理解了思想,我们来看PyTorch代码实现。这里有几个实战中容易翻车的细节,我结合自己的经验给你捋清楚。
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
def __init__(self, smooth=1e-7, reduction='mean'):
super(DiceLoss, self).__init__()
self.smooth = smooth # 防止分母为零的小常数
self.reduction = reduction
def forward(self, preds, targets):
"""
preds: 网络原始输出 (logits), 形状 [B, C, H, W] 或 [B, H, W]
targets: 真实标签,形状与preds相同,值应为0或1
"""
# 1. 激活函数:将logits转为概率
# 注意:如果网络最后一层已经是sigmoid,这里可以省略。
# 但更通用的做法是网络输出logits,在这里统一用sigmoid激活。
preds_probs = torch.sigmoid(preds)
# 2. 展平数据:忽略批次和通道维度,将所有像素视为一维向量
# 这样写兼容单通道和多通道(需分别处理各通道)
preds_flat = preds_probs.contiguous().view(-1)
targets_flat = targets.contiguous().view(-1).float() # 确保targets是float
# 3. 计算交集、并集(或总面积)


9295

被折叠的 条评论
为什么被折叠?



