Focal Loss总结

1. intro

Focal loss主要是为了解决样本不均衡问题,该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。

2. 基本原理

2.1 二分类损失函数

L=−ylogy′−(1−y)log⁡(1−y′)={−log⁡y′y=1−log⁡(1−y′),y=0\mathrm{L}=-\mathrm{ylogy}^{\prime}-(1-y) \log \left(1-y^{\prime}\right)=\left\{\begin{array}{ll} -\log y^{\prime} & y=1 \\ -\log \left(1-y^{\prime}\right), & y=0 \end{array}\right.L=ylogy(1y)log(1y)={logylog(1y),y=1y=0

y′y^{\prime}y 是sigmoid 函数的输出, 值再 0-1 之间,。可见普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。此时的损失函数在大量简单样本的迭代过程中比较缓慢且可能无法优化至最优。

其实, sigmoid 函数输出 y′y^{\prime}y 在0-1 之间,令 y′>0.5y^{\prime} > 0.5y>0.5时属于正类,小于0.5时为负类,所以当样本属于正类时,其输出越大,即置信度越大,越可能是正类,所以损失函数越小,同理,对于负类输出概率越小,损失函数越小

2.2 Focal Loss

上文中二分类交叉熵损失为
CE(p,y)={−log⁡(p) if y=1−log⁡(1−p) otherwise \mathrm{CE}(p, y)=\left\{\begin{array}{ll} -\log (p) & \text { if } y=1 \\ -\log (1-p) & \text { otherwise } \end{array}\right.CE(p,y)={log(p)log(1p) if y=1 otherwise 

其中 p∈[0,1]p\in[0, 1]p[0,1]为模型预测正例概率值,令
pt={p if y=11−p otherwise p_{\mathrm{t}}=\left\{\begin{array}{ll} p & \text { if } y=1 \\ 1-p & \text { otherwise } \end{array}\right.pt={p1p if y=1 otherwise 

所以:
CE(p,y)=CE(pt)=−log⁡ptC E(p, y)=C E\left(p_{t}\right)=-\log p_{t}CE(p,y)=CE(pt)=logpt

Focal Loss 在交叉熵损失上增加一个调节因子(1−pt)γ\left(1-p_{t}\right)^{\gamma}(1pt)γ, FL的定义如下:
FL(pt)=−(1−pt)γlog⁡ptF L\left(p_{t}\right)=-\left(1-p_{t}\right)^{\gamma} \log p_{t}FL(pt)=(1pt)γlogpt

ptp_tpt很小时, 调节因子值很接近1, loss不受影响, 当ptp_tpt趋于1时, 调节因子接近0, 这样已经能正确分类的简单样例 loss 大大降低。超参数 γ\gammaγ 为0时,FL等价于CE,论文中发现取2时是最好的,此时若一个样本的 ptp_tpt 为0.9,其对应的CE loss是FL的100倍,可见FL相比CE可以大大降低简单例子的loss,使模型训练更关注于难例。

举例:
例如gamma为2,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的 γ\gammaγ 次方就会很小,这时损失函数值就变得更小。而预测概率为0.3的样本其损失相对很大。对于负类样本而言同样,预测0.1的结果应当远比预测0.7的样本损失值要小得多。对于预测概率为0.5时,损失只减少了0.25倍,所以更加关注于这种难以区分的样本。这样减少了简单样本的影响,大量预测概率很小的样本叠加起来后的效应才可能比较有效

此外,加入平衡因子alpha,用来平衡正负样本本身的比例不均:文中 α\alphaα 取0.25,即正样本要比负样本占比小,这是因为负例易分。

Lfl={−α(1−p)γlog⁡py=1−(1−α)pγlog⁡(1−p),y=0L_{f l}=\left\{\begin{array}{ll} -\alpha\left(1-p\right)^{\gamma} \log p & y=1 \\ -(1-\alpha) p^{ \gamma} \log \left(1-p\right), & y=0 \end{array}\right.Lfl={α(1p)γlogp(1α)pγlog(1p),y=1y=0

添加α\alphaα可以平衡正负样本的重要性,但是无法解决简单与困难样本的问题。

γ\gammaγ 调节简单样本权重降低的速率,当 γ\gammaγ 为0时即为交叉熵损失函数,当 γ\gammaγ 增加时,调整因子的影响也在增加。实验发现 γ\gammaγ 为2是最优。

3. 程序实现

def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值