Sampling-bias-corrected neural modeling 论文阅读

  • 论文Sampling-Bias-Corrected Neural Modeling for Large Corpus Item Recommendations阅读
  • 有部分参考https://blog.csdn.net/whgyxy/article/details/123671147;
  • 用pytorch进行了初步复现,并用MovieLens数据集小样本跑了一下,后续准备增加使用GCN的情况。详见https://github.com/gys-syg/Reproduction-of-Sampling-bias-correction,个人学习记录,有问题欢迎指出交流;
  • 复现时参考了https://github.com/LongmaoTeamTf/deep_recommenders,不过这个是用的tensorflow。

1 前置知识

1.1 batch

  • 在深度学习中,batch(小批次) 是指在一次前向传播(forward pass)和反向传播(backward pass)中所使用的一组样本。
    例如:
    • 你的训练数据集可能有几百万个样本;
    • 不可能一次性全部放入显存训练;
    • 所以你每次取其中一小部分(比如 512、1024 条样本),组成一个 batch;
    • 用这一批样本计算损失(loss)并更新参数。
      暂时无法在飞书文档外展示此内容

1.2 数据流

在该论文中,流数据指的是动态生成、持续更新且无固定边界的训练数据,核心特征是数据分布随时间变化,且物品库(如视频、页面等推荐对象)会不断新增或调整,不存在静态、固定的物品词汇表。
具体而言,论文中流数据的场景体现为:

  1. 数据生成的连续性:以YouTube推荐场景为例,用户的点击、观看等反馈数据每日都会大量产生,训练数据按天组织,模型需按从旧到新的顺序持续消费这些新增数据,而非依赖一个静态的、一次性的数据集。
  2. 物品分布的动态性:物品(如YouTube视频)的流行度会随时间变化(例如新视频上传、旧视频热度下降),导致物品的采样概率分布持续调整,模型需自适应这种分布偏移,而非假设物品分布固定不变。
  3. 无固定词汇表约束:传统采样方法常依赖预先定义的物品词汇表(即已知所有待推荐物品),但流数据场景下物品会不断新增(如YouTube每日新增大量视频),无法预先确定完整物品集合,因此论文提出的流式物品频率估计算法无需固定词汇表,可通过哈希数组动态记录新增物品的采样信息。
    论文针对流数据的这些特点,设计了适配的解决方案:通过维护哈希数组与全局步骤的流式频率估计算法,实时更新物品采样概率,确保模型在数据持续流入、分布动态变化的场景下,仍能实现无偏的采样偏差校正。

2 论文内容

2.1 摘要

  • 推荐系统通常处理数据稀疏性和长尾分布的方式,是对物料的内容特征进行embedding得到向量表示。
  • 除了使用基于矩阵分解的内容感知系统,论文考虑双塔结构,其中一个塔是编码了很多内容特征的物料塔。训练这类双塔模型的常规方法,是对基于batch内负样本计算的损失函数进行优化 —— 这些负样本是从batch中随机采样得到的物品。然而,batch内损失会受到采样偏差的影响,这可能损害模型性能,在物品分布高度倾斜(如长尾分布)的情况下尤为明显。
  • 论文提出了一种从流数据中估计物品频率的新算法。通过理论分析与仿真实验,并证明所提算法无需固定的物品词汇表即可运行,能够实现无偏估计,并且可以自适应物品分布的变化。随后,论文将这种经采样偏差校正的建模方法应用于构建面向 YouTube 推荐的大规模神经检索系统。该系统已部署上线,用于从包含数千万个视频的语料库中检索个性化推荐内容。— 即预测物品分布频率,并在负采样中增加频率较高物品的权重!
  • 论文通过在两个真实世界数据集上开展离线实验,验证了采样偏差校正的有效性;同时还进行了在线 A/B 测试,结果表明该神经检索系统能够提升 YouTube 的推荐质量。

2.2 Introduction

  1. “Batch Softmax Optimization” 为训练双塔 DNN 的常规方法。其核心逻辑是,在计算物品概率时,无需遍历全部物品库,而是仅基于随机小批次(minibatch)中的所有物品进行计算 —— 这种方式能大幅降低大规模物品库(如 YouTube 的千万级视频库)下的训练复杂度;
  2. 传统采样方法常依赖预先定义的物品词汇表(即已知所有待推荐物品),但流数据场景下物品会不断新增(如 YouTube 每日新增大量视频),无法预先确定完整物品集合;
  3. 但Batch Softmax Optimization 存在显著的采样偏差。这一偏差源于物品的幂律分布(长尾分布)特性 —— 热门物品因在数据中出现频率更高,更易被选入训练批次并作为 “负样本” 参与损失计算,导致模型过度惩罚热门物品、过度偏向热门物品的推荐(按经验来说,如果不用batch,而是用简单随机抽样选取负样本的话,抽样概率应为(点击次数)^0.75,而batch抽样概率直接(相当于)正比于点击次数,这样显然会过度惩罚热门作品),同时忽略长尾物品(反馈少但可能匹配用户兴趣的物品),最终严重限制模型性能;
  4. 本文贡献:
    • 流式物品频率估计算法(Streaming Frequency Estimation):
      针对传统采样方法依赖 “固定物品词汇表”、无法适配流式数据(物品库动态更新、分布随时间变化)的痛点,论文提出一种全新的物品频率估计算法。该算法无需预设固定物品集合,通过维护哈希数组记录物品的最新采样步骤与采样间隔,结合指数移动平均(EMA)实现物品采样概率的无偏估计;同时通过数学推导证明,随着采样步数增加,估计偏差会趋近于 0,且可通过调整学习率平衡 “分布适应性” 与 “估计方差”。此外,仿真实验验证了该算法能有效捕捉数据分布的动态变化(如物品热度骤升 / 骤降),为后续偏差校正提供精准的频率依据;
    • 通用大规模检索系统建模框架(Modeling Framework):
      论文构建了一套可复用的大规模检索系统建模框架,核心是将 “流式估计的物品频率” 融入批次 Softmax 的交叉熵损失中,实现采样偏差校正。传统双塔模型因 “批次内负样本采样偏差”(热门物品更易入批、过度被惩罚)导致性能受限,而该框架通过对物品 logit(用户 - 物品嵌入点积)按采样概率进行校正,抵消偏差影响;同时支持嵌入归一化、温度参数调优等优化手段,兼顾模型训练稳定性与检索准确性。该框架不依赖特定物品类型(如视频、网页),可直接迁移至各类大规模语料推荐场景;

2.3 相关工作

  • 简单介绍softmax和双塔模型;

2.4 模型框架

  1. 推荐任务被抽象为一组 query(查询) 和 item(候选物料):
    {xi}i=1N, {yj}j=1M \{x_i\}_{i=1}^N ,\ \{y_j\}_{j=1}^M{xi}i=1N, {yj}j=1M
    每个 x 或 y 都是高维特征(用户 + 上下文、视频 + 内容等)。
    目标:给定 x 检索最相关的一小部分 y。
  2. 双塔模型结构
    论文定义两个可学习的嵌入函数:
    u(x;θ), v(y;θ)∈Rk u(x;\theta),\ v(y;\theta) \in \mathbb{R}^k u(x;θ), v(y;θ)Rk
    它们通常是两个 DNN 塔(左塔处理 user/context ,右塔处理 item 特征)。
    得分函数:
    s(x,y)=⟨u(x;θ),v(y;θ)⟩s(x,y)=\langle u(x;\theta),v(y;\theta)\rangles(x,y)=u(x;θ),v(y;θ)⟩
    即二者的内积。
  3. 损失函数推导
    训练集 T 包含样本三元组 (xi,yi,ri)(x_i,y_i,r_i)(xi,yi,ri),其中rir_iri是用户对 yiy_iyi的反馈强度(奖励/权重)。
    若把推荐视为一个多分类问题(每个 item 是一个类别):
    P(y∣x;θ)=es(x,y)∑j=1Mes(x,yj)P(y|x;\theta)=\frac{e^{s(x,y)}}{\sum_{j=1}^M e^{s(x,y_j)}}P(yx;θ)=j=1Mes(x,yj)es(x,y)
    损失就是加权交叉熵:
    LT(θ)=−1T∑irilog⁡P(yi∣xi;θ)L_T(\theta)=-\frac{1}{T}\sum_i r_i\log P(y_i|x_i;\theta) LT(θ)=T1irilogP(yixi;θ)
    但当 MMM(候选物料总数)巨大时,分母无法计算。
  4. Batch Softmax 近似
    实际只取一个小 batch 内的 B 个 item 计算 softmax:
    PB(yi∣xi;θ)=es(xi,yi)∑j=1Bes(xi,yj)P_B(y_i|x_i;\theta)= \frac{e^{s(x_i,y_i)}}{\sum_{j=1}^B e^{s(x_i,y_j)}}PB(yixi;θ)=j=1Bes(xi,yj)es(xi,yi)
    这种“in-batch negative sampling”效率高,但引入采样偏差——热门 item 更常出现在 batch 中,被过度当作负样本惩罚。
  5. 采样偏差修正(Bias Correction)
    论文借鉴 sampled softmax 的 log-Q 校正思想:
    sc(xi,yj)=s(xi,yj)−log⁡pjs_c(x_i,y_j)=s(x_i,y_j)-\log p_jsc(xi,yj)=s(xi,yj)logpj
    其中pjp_jpj是 item 在随机 batch 中被采到的概率。
    改正后概率:
    PBc(yi∣xi;θ)=esc(xi,yi)∑j=1Besc(xi,yj)P_B^c(y_i|x_i;\theta)= \frac{e^{s_c(x_i,y_i)}}{\sum_{j=1}^B e^{s_c(x_i,y_j)}}PBc(yixi;θ)=j=1Besc(xi,yj)esc(xi,yi)
    相应的 batch loss:
    LB(θ)=−1B∑irilog⁡PBc(yi∣xi;θ)L_B(\theta)=-\frac{1}{B}\sum_i r_i\log P_B^c(y_i|x_i;\theta)LB(θ)=B1irilogPBc(yixi;θ)
    再用 SGD 更新 θ\thetaθ
    θ←θ−γ∇θLB(θ)\theta \leftarrow \theta - \gamma\nabla_\theta L_B(\theta)θθγθLB(θ)
  6. 算法1框架:训练流程
    1. 从数据流采样一个 batch (xi,yi,ri)(x_i,y_i,r_i)(xi,yi,ri)
    2. 用算法 2 估计每个 yiy_iyi 的采样概率 pip_ipi
    3. 计算带 bias-correction 的 batch loss LB(θ)L_B(\theta)LB(θ)
    4. 反向传播更新参数。
    注意:这不需要固定物料集合,可直接处理流式数据(vocabulary 随时间变化)。

  1. 推理阶段
    学得 u,vu,vu,v后,推理时:
    1. 计算 query 的 embedding u(x,θ)u(x,θ)u(x,θ)
    2. 在预先索引的 item embedding 集合 v(y,θ)v(y,θ)v(y,θ) 中做近邻搜索(Approximate MIPS )。
  2. 为了稳定训练,论文还做了:
    - L2L2L2 归一化:把 u,vu,vu,v 映射到单位球;
    - 温度参数 τττ:调节 softmax 平滑度,
    s(x,y)=⟨u(x),v(y)⟩τs(x,y)=\frac{\langle u(x),v(y)\rangle}{\tau}s(x,y)=τu(x),v(y)⟩

2.5 数据流频率估计

  • 要想估计item出现的概率 ppp,可以通过估计item两次出现的间隔 δ\deltaδ,也就是步长(若 δ\deltaδ 为50,那么 ppp 就为0.02)。δ\deltaδ 的更新可以通过一个简单的移动平均更新来实现,对分布有自适应性;
  • 算法结构:
    论文维护两个数组(或“表”):
    • A:记录每个 item(或哈希桶)上次出现的时间步;
    • B:记录平均间隔的滑动估计。
      如果内存不够记录所有 item,就对 item id 做哈希,把很多 item 映射到有限数量的桶。
# 初始化
A = zeros(H)  # H个哈希桶
B = ones(H) * default_interval  # 初始平均间隔估计

for step, item in enumerate(data_stream, start=1):

    h = hash(item) % H  # 把item映射到桶

    # 计算自上次出现以来的时间间隔
    delta = step - A[h]

    # 更新平均间隔的滑动估计(指数移动平均)
    B[h] = (1 - alpha) * B[h] + alpha * delta

    # 更新最近出现的时间
    A[h] = step

# 最终对任意item y
# 估计出现概率
p_hat[y] = 1 / B[hash(y) % H]

H:哈希桶数(越大越精确);
alpha:平滑参数(论文推荐 0.01);
delta:item 出现的时间间隔;
B[h]:该桶当前估计的平均间隔;
py≈1/B[h(y)]p_y \approx 1 / B[h(y)]py1/B[h(y)]

  • 每当一个 item(记作 yty_tyt)到达,有如下更新:
    {δt=t−A[h(yt)]B[h(yt)]←(1−α)B[h(yt)]+α δtA[h(yt)]←t\begin{cases} \delta_t = t - A[h(y_t)] \\ B[h(y_t)] \leftarrow (1-\alpha)B[h(y_t)] + \alpha\,\delta_t \\ A[h(y_t)] \leftarrow t \end{cases}δt=tA[h(yt)]B[h(yt)](1α)B[h(yt)]+αδtA[h(yt)]t
    估计出的采样概率:
    pyt=1B[h(yt)]p_{y_t} = \frac{1}{B[h(y_t)]}pyt=B[h(yt)]1
  • 哈希碰撞的频繁发生(一般是热门作品在哈希桶中时)会使估计的间隔偏小,p偏大,为此改进算法3:
    • 使用多个独立哈希函数h1,h2,…,hmh_1, h_2, \dots, h_mh1,h2,,hm
    • 每个维护独立的 Ai,BiA_i, B_iAi,Bi数组;
    • 对 item yyy,得到多个候选桶间隔;
    • 取其中 最大的间隔(最保守的)作为最终估计:
      B^y=max⁡iBi[hi(y)],p^y=1B^y\hat B_y = \max_i B_i[h_i(y)], \quad \hat p_y = \frac{1}{\hat B_y}B^y=imaxBi[hi(y)],p^y=B^y1
  • 直觉(多哈希+取最大间隔方法):
    • 若某个哈希桶被污染(频繁更新,B太小),会让 p^\hat pp^ 偏高;
    • 但其它独立哈希桶未必也碰到同样热门项;
    • 取最大间隔(最不频繁的)→ 能抵消“热门污染”的影响;
    • 所以这一步 修正了高估偏差。
  • 以上估计的无偏性在论文命题4.1有证明,不难。

2.6 模型概述-Youtube神经检索模型

  1. 标签选择:视频点击为正标签。另外构造了一个奖励 rir_iri 来反映不同程度的用户对视频的参与程度。例如,对于观看时间较少的点击视频, ri=0r_i = 0ri=0。另一方面,ri=1r_i = 1ri=1 表示整个视频得到了观看。奖励作为损失函数中的权重。
  2. 视频特征:包含类别特征和稠密特征。类别特征:视频ID,渠道ID,用embedding层来将这些类别特征映射为稠密向量。通常有两种类别特征,一种是类别特征只有一个取值,像视频ID特征,每个视频只有一个ID,还有一种类别特征可以取多个值,像视频主题,每个视频可以有多个主题,最后的embedding就是各embedding的加权和。对于不在词汇表中的实体,随机分配到散列通中,然后学习这些散列桶的embedding,散列桶对于模型捕获新实体非常重要。
  3. 用户特征:用用户观看的视频历史来学习用户兴趣,例如,取k个用户最近观看的视频列表。将这些历史视频ID视作词袋(Bag Of Words),通过视频ID的embedding的平均向量来表示。
  4. 在查询塔中,用户和种子视频特征在输入层进行融合,然后通过前馈神经网络进行传递。
  5. 对于相同类型的ID,embedding是共享的(可见图中下面这些交叉线条)。例如,对于视频ID特征,既作为视频侧特征,也是用户侧的观看历史特征,它们的embedding是共享的。实验发现这类特征不共享并没有显著提升。

2.7 实验

2.7.1 评估算法2、3有效性(估计数据流中item出现的概率所用算法):

  1. 实验设计思路

    • 他们假设有一个固定的物品集合 MMM,每个物品按概率 qiq_iqi 被采样。
    • 一开始用概率 qi∝i2q_i \propto i^2qii2 采样物品(偏向高索引的 item)。
    • 在训练进行到某一步(第 10,000 步)后,切换分布为 qi∝(M−1−i)2q_i \propto (M - 1 - i)^2qi(M1i)2,即分布方向反转(偏向低索引的 item)。
    • 每一步从该分布中采样一个 batch(大小为 BBB),然后用算法估计每个 item 的采样概率 pi=∣B∣×qip_i = |B| \times q_ipi=B×qi
    • 评价指标:用一个归一化的 L1L1L1 距离 衡量真实分布 {pi}\{p_i\}{pi} 与估计值 {pi^}\{\hat{p_i}\}{pi^} 的差异:
      error=12∣B∣∑i∣pi^−pi∣\text{error} = \frac{1}{2|B|} \sum_i | \hat{p_i} - p_i |error=2∣B1ipi^pi
      它也可以理解为两分布之间的总变差距离(total variation)。
  2. 实验 1:学习率 α\alphaα 的影响

    • 参数设置:M=1000,B=128M=1000,B=128M=1000B=128,哈希表大小 H=5000H=5000H=5000,初始 A=0,B=100A=0,B=100A=0B=100
    • 在不同的 α\alphaα 下运行 Algorithm 2。
    • 结果(见 Figure 4):
      • 所有曲线最终都收敛到一个稳定误差(来源于哈希碰撞与方差)。
      • 较高 α\alphaα → 适应分布变化更快,但方差更大。
      • 较低 α\alphaα → 更平滑但反应慢。
      • 与前面理论分析(Proposition 4.1)一致:α\alphaα 控制偏差与方差的平衡。
  3. 实验 2:多哈希函数的影响

    • 使用 Algorithm 3,测试不同哈希函数数量 m = 1, 2, 4。
    • 保持哈希桶总数相同(即每个数组更小但数组数量更多,保持参数总量不变)。
    • 结果(见 Figure 5):
      • 增加哈希函数数量显著降低估计误差。
      • 原因:多哈希可以缓解碰撞带来的过估或欠估,提高鲁棒性。
      • 在相同参数规模下,多哈希方案(类似 count-min sketch 的思想)更准确。

2.7.2 Wikipedia 页面链接预测任务 验证 采样偏差修正(sampling-bias correction) 的有效性

  1. 数据:目的是预测维基百科页面之间的链接,对于给定的源页面和目的页面 (x,y)(x, y)(x,y) ,label=1label = 1label=1表示源页面xxx有一个链接到目的页面yyylabel=0label = 0label=0则表示没有。每个页面由页面的特征表示,例如页面url,标题、类目的词袋表示。实验使用英文页面,包含530万页面,4.3亿链接,51万标题n-grams,40.34万类目n-grams。
  2. 模型:将链接预测视为检索召回任务,给定一个源页面,从页面集合中去召回目的页面。双塔中左右塔分别表示源页面和目的页面,输入特征embedding双塔共享,每个塔是个全连接的ReLU层,维度是[512,128][512, 128][512,128]
  3. 论文提出的基于采样修正(sampling-bias-corrected batch softmax )的batch softmax,对比的是没有采样修正(batch softmax without any correction )的batch softmax。为论证采样偏差的影响,采用均方误差(MSE)来衡量,加上正则项,loss为
    1∣Ω∣∑(xi,yi)∈Ω(⟨u(xi),v(yi)⟩−ri)2+λ⋅1∣Ωc∣∑(xi,yi)∈Ωc⟨u(xi),v(yi)⟩2\frac{1}{|\Omega|} \sum_{(x_i, y_i) \in \Omega} \left( \langle u(x_i), v(y_i) \rangle - r_i \right)^2 + \lambda \cdot \frac{1}{|\Omega^c|} \sum_{(x_i, y_i) \in \Omega^c} \langle u(x_i), v(y_i) \rangle^2∣Ω∣1(xi,yi)Ω(u(xi),v(yi)⟩ri)2+λΩc1(xi,yi)Ωcu(xi),v(yi)2
  4. 对于每一个温度值,correct - sfx都比相应的Plain - sfx有大幅度的提高。

2.7.3 Youtube 实验

  • 离线与在线实验在不同温度参数τ\tauτ下都有明显效果:

3 问题与思考

  1. 为什么不在batch抽样的时候把抽样概率直接设置为(点击次数)^0.75,来提高热门物品成为负样本的概率,而是通过优化相似度呢?
  • 论文的核心应用场景是 YouTube 等大规模推荐系统,这类场景的物品库(如视频)处于持续动态更新中 —— 新物品不断上传、旧物品热度随时间衰减,形成流式数据,且物品的 “点击次数” 等热门度指标会实时变化。
  • 若采用 “(点击次数)^0.75” 这类固定形式的抽样概率,需依赖 “预先统计所有物品的点击次数” 并定期更新概率分布,但论文明确指出,传统方法依赖的 “固定物品词汇表” 在流数据场景下不成立:新增物品无历史点击数据,无法预先计算抽样概率;而物品热度的快速变化也会导致 “基于历史点击的抽样概率” 迅速过时,进而引发新的偏差。
  • 相比之下,论文提出的 “通过优化相似度校正偏差” 方案,本质是基于实时流式估计的物品采样概率(而非静态点击次数)进行校正:通过哈希数组实时追踪物品的采样间隔,动态计算物品的真实采样概率 p^\hat{p}p^,再通过 cos(a,bi)−logp^cos(a,b_i)-log\hat{p}cos(a,bi)logp^(或论文中的 logit 校正公式)抵消偏差,无需依赖固定的热门度幂次(如 0.75),能自适应物品分布的动态变化。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值