1. 对比学习与度量学习:为什么我们需要这些特殊的损失函数?
如果你接触过深度学习,肯定对交叉熵损失(Cross-Entropy Loss)和均方误差(MSE)这些老朋友不陌生。它们就像厨房里的盐和酱油,是解决分类、回归这些标准任务的“基础调味品”。但当你开始捣鼓人脸识别、商品推荐、图像检索这些更“高级”的菜式时,你会发现,光有盐和酱油不够用了。你需要一些更特别的“香料”,比如我们今天要聊的这五大损失函数:Triplet Loss、Ranking Loss、Margin Loss、Center Loss和InfoNCE Loss。
这些损失函数都属于一个更广阔的领域:度量学习 和 对比学习。它们的目标不是直接预测一个标签或一个数值,而是去学习一个“好的”特征空间。在这个空间里,相似的东西(比如同一个人的不同照片)会靠得很近,而不相似的东西(比如不同人的照片)则会离得很远。你可以把它想象成训练一个“审美专家”,它不关心图片里具体是什么,但能精准判断两张图片在某种意义上的“相似度”。
我刚开始做人脸识别项目的时候,就踩过直接用分类损失的坑。当时我用一个标准的卷积神经网络加上Softmax分类头,在训练集上准确率能刷到99%,但一到实际场景,识别新的人脸时效果就大幅下降。后来才明白,Softmax只关心“这张脸是不是张三”,它学习到的特征可能只是刚好能把训练集的类别分开,但并没有保证“所有张三的脸”在特征空间里都聚在一起。换句话说,它缺乏类内紧凑性和类间可分性。这时候,就需要引入我们今天要讲的这些“对比”或“度量”损失来帮忙了。
这些损失函数虽然名字各异,公式也长得不太一样,但核心思想是相通的:通过比较样本对(或三元组)在特征空间中的相对距离,来引导模型学习具有判别力的特征表示。 下面,我们就一个个拆开来看,它们具体是怎么工作的,又分别适合用在什么场景。
2. Triplet Loss:让模型学会“比差距”
Triplet Loss可能是这五个里面名气最大、应用也最广的一个了,它因谷歌2015年的FaceNet论文而广为人知。它的思想非常直观,就像我们小时候玩的“找不同”游戏,但规则是“找差距”。
2.1 核心思想与公式推导
Triplet Loss每次需要看三张图片,我们称之为一个“三元组”:一张锚点图片(Anchor),一张与锚点同类的正样本图片(Positive),一张与锚点不同类的负样本图片(Negative)。它的目标很明确:让锚点与正样本之间的距离,小于锚点与负样本之间的距离,而且最好还要小出一个“安全边际”。
用数学公式表示就是这样:
L = max( d(a, p) - d(a, n) + margin, 0 )
这里,d() 是距离函数,通常用欧氏距离。margin 是一个大于0的超参数,你可以把它理解为“容忍度”或“安全距离”。损失函数的值就是正样本距离减去负样本距离再加上margin,但如果这个值小于0,我们就取0(因为目标已经达到了)。
我举个例子你就明白了。假设我们做人脸识别,锚点 a 是张三的一张证件照,正样本 p 是张三的生活照,负样本 n 是李四的照片。模型的目标是,让 a 和 p 的特征距离尽可能小(比如0.1),同时让 a 和 n 的特征距离尽可能大(比如1.2)。如果我们设 margin=0.5,那么损失就是 0.1 - 1.2 + 0.5 = -0.6,小于0,所以最终损失为0,模型无需更新。如果 a 和 n 的距离不够远(比如只有0.4),那么损失就是 0.1 - 0.4 + 0.5 = 0.2,模型就会收到一个正的损失信号,驱动它去调整参数,把 a 和 n 推得更开。
2.2 三元组的种类与负样本挖掘
在实际训练中,不是所有三元组都是有用的。根据难度,我们可以把三元组分为三类:
- 简单三元组:负样本已经比正样本远很多了(
d(a, n) > d(a, p) + margin)。损失为0,对训练没贡献,白算。 - 困难三元组:负样本比正样本还近(
d(a, n) < d(a, p))。这是最需要纠正的错误,损失很大。 - 半困难三元组:负样本比正样本远,但没远出margin(
d(a, p) < d(a, n) < d(a, p) + margin)。损失为正但不大,是训练的重点。
这里就引出了Triplet Loss训练的一个关键技巧:负样本挖掘。如果你随机从数据里选负样本,大部分都是“简单三元组”,训练效率极低,模型学不到什么。所以,我们必须有策略地去挑选那些“困难”或“半困难”的负样本。
我常用的策略有两种:
- 离线挖掘:在一个epoch开始前,用当前模型把所有样本的特征都提取出来,预先计算好所有可能的三元组,把困难的和半困难的挑出来用于训练。这种方法效果好,但计算开销大,内存要求高。
- 在线挖掘:在每个训练批次内部动态挖掘。这是更主流的方法。比如“Batch Hard”策略:对于批次内的每个锚点,选择距离它最远的正样本(最难的正样本)和距离它最近的负样本(最难的负样本)来构造三元组。这样产生的三元组都是“硬骨头”,训练效率高。
下面是一个PyTorch中实现Batch Hard Triplet Loss的简化代码示例,你可以感受一下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class BatchHardTripletLoss(nn.Module):
def __init__(self, margin=0.5):
super().__init__()
self.margin = margin
def forwar


7482

被折叠的 条评论
为什么被折叠?



