【ACM MM‘25】联邦学习也玩“找平坦谷底”?一文读懂 FedNSAM | Junkang Liu(刘俊康)

Consistency of Local and Global Flatness for Federated Learning

【联邦学习也玩“找平坦最小值”?一文读懂 FedNSAM】 Junkang Liu(刘俊康)

在这里插入图片描述
论文摘要:

在联邦学习(Federated Learning, FL)中,多步本地更新和数据异质性通常会导致更“尖锐”的全局极小值,从而削弱全局模型的性能。为了解决这一问题,许多主流 FL 算法会在本地训练中引入 sharpness-aware minimization(SAM)。然而,在高数据异质性的场景下,本地训练中得到的“平坦性”并不一定意味着全局模型同样是平坦的。因此,仅仅在各客户端数据上最小化本地损失曲面的锐度,并不能保证 SAM 在 FL 中真正提升全局模型的泛化能力。

为了解释这一现象,我们首先引入了一个名为 平坦距离(flatness distance) 的度量,并对其进行了重新思考与理论分析。在此基础上,我们提出了一种全新的 FedNSAM 算法:通过在本地更新中引入全局 Nesterov 动量(global Nesterov momentum),来加速 SAM,并协调全局与本地平坦性的“一致性”。FedNSAM 使用全局 Nesterov 动量作为客户端对全局扰动进行本地估计以及外推(extrapolation)的方向。

从理论上,我们基于 Nesterov 外推证明了 FedNSAM 相比 FedSAM 具有更紧的收敛上界;从实验上,我们在 CNN 和 Transformer 模型上进行了大量实验,验证了 FedNSAM 在性能和效率上的显著优势。代码已开源,地址为:
https://github.com/junkangLiu0/FedNSAM
一作 Junkang Liu(刘俊康)的公开资料/简介链接。

🔗 Google 学术个人主页
https://scholar.google.com/citations?user=N7pJWIoAAAAJ&hl=zh-CN ([谷歌学术][1])


ℹ️ 关于 Junkang Liu 的简要背景

  • 他目前是 Tianjin University(天津大学)的 PhD 学生,同时也曾在 Xidian University(西安电子科技大学)攻读 MS。
  • 他的研究方向包括联邦学习(Federated Learning)、优化算法、分布式训练等。([OpenReview][2])
  • 除了最近的 FedSWA / FedMoSWA(Improving Generalization in Federated Learning …)论文之外,他也参与其他与 Federated Learning 优化相关的研究,例如 FedBCGD ,FedSWA, FedAdamW,FedNSAM等。

文章目录

🎯 一句话总结:FedNSAM = 联邦学习 + SAM + 全局 Nesterov 动量
让各个客户端不再“各找各的平坦谷底”,而是一起对齐到更“宽、更稳”的全局平坦解。


0. 论文 & 代码地址

  • 📄 论文标题:FedNSAM: Sharpness-Aware Minimization in Federated Learning with Distance of Flat Minimum
  • 🧪 官方代码仓库:FedNSAM (PyTorch 实现)
    👉 GitHub:https://github.com/junkangLiu0/FedNSAM

1. 背景:联邦学习为啥这么“难训稳”?

联邦学习(Federated Learning, FL)已经在医疗、金融、移动端等场景广泛应用:

  • 数据保留在本地,不上传原始数据,提升隐私保护;
  • 服务器只负责聚合来自各个客户端的模型更新。

然而,真正落地时会遇到几个经典痛点:

  • 🔸 数据高度非 IID:每个客户端数据分布不一样,比如:

    • 医院 A 以老年人为主;
    • 医院 B 以年轻人为主;
  • 🔸 本地多步更新(local epochs 多):每个客户端都会按自己数据“疯狂更新”;

  • 🔸 全局模型容易掉进“尖锐谷底”

    • 每个客户端都朝着“自己局部的低谷”走;
    • 聚合后的全局模型,可能刚好落到“谷与谷之间的缝隙”。

结果就是:训练集 loss 也许不难看,但测试集效果很拉胯,泛化性能差。


2. SAM 回顾:别只看一个点,要看“附近一片区域”

2.1 传统训练:只在意当前点

传统优化目标:

min ⁡ θ F ( θ ) \min_{\theta} F(\theta) θminF(θ)
只关心在某个参数点 (\theta) 上的损失值,不关心它附近区域是不是也好

这会导致:

  • 在训练集上“刚刚好”很低;
  • 但稍微扰动一下参数,loss 就爆表;
  • 这类解被称为 sharp minima(尖锐最小值),泛化表现往往不好。

2.2 SAM 的核心思想:最坏扰动下也要好

SAM(Sharpness-Aware Minimization)的目标是:

让当前点周围一小片“球形邻域”内的最坏点,也尽量损失不大。

数学形式:

min ⁡ θ max ⁡ ∣ δ ∣ 2 ≤ ρ F ( θ + δ ) \min_\theta \max_{|\delta|_2 \le \rho} F(\theta + \delta) θminδ2ρmaxF(θ+δ)
操作流程(简化版):

  1. 在当前参数 (\theta) 附近加一个扰动 δ \delta δ ∣ δ ∣ ≤ ρ |\delta| \le \rho δρ
  2. 找到这块邻域里损失最大的点
  3. 让这个“最坏点”的损失也尽可能小。

直观理解:

  • 如果一个解附近都“比较平”,就算被扰动一下,损失也不会太差;
  • 如果是尖锐谷底,一点点动就损失暴涨,SAM 会主动避开这类解。

2.3 FedSAM:把 SAM 搬进联邦学习

FedSAM 的做法:

  • 每个客户端本地训练时,用 SAM 做优化;
  • 服务器只做普通参数聚合。

数据异质性不严重的场景下,FedSAM 能显著提升全局模型的泛化性能。

但论文发现:

当数据高度非 IID 时,各客户端找到的“本地平坦最小值”彼此差距很大,
聚合后的全局模型依然不平坦,sharpness 很高,泛化仍然糟糕。

所以作者提出一个新问题:
👉 如何衡量“本地平坦最小值”与“全局模型”的距离?


3. 平坦距离(Flatness Distance):把“几何形状”量化出来

3.1 定义:客户端最小值 vs 全局模型

设第 (t) 轮训练后:

  • 客户端 (i) 做完 (K) 次本地更新后的模型为:(\theta_{i,K}^t);
  • 服务器聚合后的全局模型为:

θ t + 1 = 1 N ∑ i = 1 N θ i , K t \theta^{t+1} = \frac{1}{N}\sum_{i=1}^N \theta_{i,K}^t θt+1=N1i=1Nθi,Kt

论文定义 平坦距离(Flatness Distance)为:

Δ D = 1 N ∑ i = 1 N E ∣ θ i , K t − θ t + 1 ∣ 2 \Delta_D = \frac{1}{N} \sum_{i=1}^N \mathbb{E} |\theta_{i,K}^t - \theta^{t+1}|^2 ΔD=N1i=1NEθi,Ktθt+12

直觉理解:

  • 每个客户端都有一个“自认为平坦的谷底”;
  • 全局模型是这些谷底的“平均值”;
  • 如果这些谷底彼此靠得很近,全局模型也会在一个大致平坦的区域;
  • 如果它们离得很远,平均后全局模型可能掉在**“谷与谷之间的缝隙里”**,变得很尖锐。

在这里插入图片描述

  • 上面三张图:

    • (a) FedSAM, α=0.6:数据异质性不高时,损失曲面比较平坦,FedSAM 还能找到不错的“宽谷底”。
    • (b) FedSAM, α=0.1:数据高度非 IID 时,各客户端学到的东西差太多,聚合后全局模型落在“谷底之间的缝隙”,整体变“尖”,泛化变差。
    • © FedNSAM, α=0.1:即使在高度非 IID 下,通过全局 Nesterov 动量对齐方向,FedNSAM 让全局模型重新回到更宽、更稳的谷底。
  • 下面三张示意图:

    • (d) 低异质:各客户端的平坦区域重叠多,平均后还在平坦区。
    • (e) 高异质:平坦区分得很开,平均后掉到平坦区之外。
    • (f) FedNSAM 通过修正(全局动量),让客户端的更新方向更一致,最终的全局模型回到“共同的平坦低损失区域”里。

3.2 关键发现:平坦距离越大,全局越“尖”

在这里插入图片描述
下面是对 论文图 2(Figure 2)简洁清晰解释


🟩 图 2 的核心想表达一句话:

FedNSAM 能显著减少平坦距离、降低全局 sharpness,因此在高异质数据下比 FedSAM 泛化更好。


🟦 左图 (a):Flatness Distance(平坦距离)

  • 曲线越低 → 客户端本地最优点之间越接近 → 全局越容易落在共同的“平坦区”。

  • 现象:

    • FedSAM(绿、红、蓝三条)随着异质性增大(Dirichlet 0.1 → 0.3 → 0.6)平坦距离更大,下降更慢。
    • FedNSAM(黄色)始终最低,下降最快 → 客户端的本地平坦区域被成功“对齐”,一致性更高。

➡️ FedNSAM 明显减少客户端之间的偏移,让平均后的全局模型更稳定。


🟧 右图 (b):Global Sharpness(全局尖锐度)

  • 曲线越低 → 全局模型越平坦 → 泛化能力越好。

  • 现象:

    • FedSAM 在高度异质(Dirichlet 0.1)下 sharpness 最高,曲线下降慢。
    • FedNSAM(黄色)全程 sharpness 最低,表示找到的全局解更平坦。

➡️ FedNSAM 有效缓解了由于高度异质造成的“尖锐全局最小值”。


🟨 文字说明总结含义

  • 论文使用 CIFAR-100、100 个客户端、10% 参与率。

  • 异质性越强(Dirichlet 越小),FedSAM 性能下降越明显:

    • 40.18%(Dir=0.1)
    • 46.02%(Dir=0.3)
    • 47.83%(Dir=0.6)
  • FedNSAM 在 Dir=0.1 时达到 58.53% → 远高于 FedSAM 的 40.18%。

➡️ FedNSAM 在高度非 IID 场景下提升巨大(+18.35 个点)。


图 2 显示:

  • 左图:FedNSAM 能让客户端本地模型更一致(平坦距离更低)。
  • 右图:FedNSAM 让全局模型更加平坦(更低 sharpness)。
  • 这两点共同解释了:FedNSAM 在非 IID 数据下远优于 FedSAM。

如果你愿意,我可以帮你把这段内容整理成一段更适合 CSDN 的可直接粘贴排版版

论文通过理论分析和实验发现:

  • 数据越非 IID, Δ D \Delta_D ΔD 越大;
  • Δ D \Delta_D ΔD 越大,全局 sharpness 越高,测试精度越差

于是问题清晰了:

我们不只要在本地“找平坦”,
还要让各客户端的平坦区域尽量对齐,减小平坦距离 Δ D \Delta_D ΔD


4. FedNSAM:用全局 Nesterov 动量“对齐平坦区域”

在这里插入图片描述
在这里插入图片描述

4.1 FedNSAM 的主意一句话

在本地做 SAM 时,引入全局 Nesterov 动量
把各客户端“找平坦”的方向,对齐成尽量一致的全局趋势。

这样:

  • 本地仍然是“平坦意识”的优化(SAM 思路在);
  • 但更新方向受全局历史信息引导;
  • 聚合后,全局更容易处在大家“共同比较平坦”的区域。

在这里插入图片描述

4.2 回顾一下 Nesterov 动量(NAG)

经典 Nesterov 加速梯度(NAG)大致为:
θ t + 1 / 2 = θ t − λ v t   v t + 1 = λ v t + η ∇ F ( θ t + 1 / 2 )   θ t + 1 = θ t + v t + 1 \begin{aligned} \theta_{t+1/2} &= \theta_t - \lambda v_t \ v_{t+1} &= \lambda v_t + \eta \nabla F(\theta_{t+1/2}) \ \theta_{t+1} &= \theta_t + v_{t+1} \end{aligned} θt+1/2=θtλvt vt+1=λvt+ηF(θt+1/2) θt+1=θt+vt+1

特点:

  • 不是在当前点上算梯度,而是在**“预判一步之后”的点**上算;
  • 收敛往往比普通动量更快。

4.3 服务器端:维护全局动量

在 FedNSAM 中,服务器维护一个全局动量向量 (\mathbf{m}_t):

m t = λ m t − 1 + Δ t \mathbf{m}_t = \lambda \mathbf{m}_{t-1} + \Delta_t mt=λmt1+Δt

其中:

  • λ \lambda λ:动量系数(论文实验中约 0.85 效果最好);
  • Δ t = 1 S ∑ i ∈ S t Δ i t \Delta_t = \frac{1}{S}\sum_{i \in S_t} \Delta_i^t Δt=S1iStΔit 是本轮参与客户端的平均更新;
  • Δ i t = θ i , K t − θ i , 0 t \Delta_i^t = \theta_{i,K}^t - \theta_{i,0}^t Δit=θi,Ktθi,0t

然后服务器更新:

θ t = θ t − 1 + m t \theta_t = \theta_{t-1} + \mathbf{m}_t θt=θt1+mt

➡️ 全局模型不再只是“简单平均”,而是叠加了历史趋势的加速。


4.4 客户端本地更新:Nesterov + SAM 融合版

在第 (t) 轮,本地第 (k) 步,客户端 (i) 的更新规则简要如下:

  1. Nesterov 预判一步

θ i , k + 1 / 4 t = θ i , k t + λ m t \theta_{i,k+1/4}^t = \theta_{i,k}^t + \lambda \mathbf{m}_t θi,k+1/4t=θi,kt+λmt

  1. SAM 扰动方向:用全局动量,而不是本地梯度

    δ i , k t = ρ ⋅ − m t ∣ m t ∣ \delta_{i,k}^t = \rho \cdot \frac{-\mathbf{m}_t}{|\mathbf{m}_t|} δi,kt=ρmtmt

  2. 得到扰动后的参数

θ i , k + 1 / 2 t = θ i , k + 1 / 4 t + δ i , k t \theta_{i,k+1/2}^t = \theta_{i,k+1/4}^t + \delta_{i,k}^t θi,k+1/2t=θi,k+1/4t+δi,kt

  1. 在扰动点上计算梯度并更新

θ i , k + 1 t = θ i , k t − η ∇ F i ( θ i , k + 1 / 2 t ; ζ i ) \theta_{i,k+1}^t = \theta_{i,k}^t - \eta \nabla F_i(\theta_{i,k+1/2}^t; \zeta_i) θi,k+1t=θi,ktηFi(θi,k+1/2t;ζi)

可以看到:

  • 本地 SAM 还在做“最坏扰动下的优化”;
  • 但扰动方向和预移动方向都依赖全局动量 m t \mathbf{m}_t mt
  • 各客户端的“平坦区域搜索路径”因此被对齐。

4.5 优势总结

在这里插入图片描述

FedNSAM 相比 FedSAM / FedAvg 等方法,主要有三点优势:

  1. 更好的理论收敛上界

    • 在适当学习率和扰动半径下,FedNSAM 的收敛速度
      O ( L F T K S ( 1 − λ ) ) \mathcal{O}\left(\frac{\sqrt{LF}}{\sqrt{TKS(1-\lambda)}}\right) O(TKS(1λ) LF )
      优于 FedSAM。
  2. 更小的平坦距离 (\Delta_D)

    • 理论上证明在同等条件下,FedNSAM 对数据异质性的敏感度更低;
    • 实验上也确实观测到: Δ D \Delta_D ΔD 更小,全局 sharpness 更低。
  3. 通信成本几乎不变

    • 每个客户端只需上传本地模型差分 Δ i t \Delta_i^t Δit
    • 不像某些方法需要额外上传控制变量 / 二阶信息;
    • 非常适合实际工程系统落地。

5. 实验结果:从 CNN 到 Transformer 全面提升

在这里插入图片描述

论文在多个数据集与网络上做了系统对比,包括:

  • CIFAR-10 / CIFAR-100 / Tiny-ImageNet;
  • LeNet-5、VGG-11、ResNet-18;
  • Swin-Small / Swin-Base / ViT-Base 等。

5.1 CIFAR-100 + ResNet-18:精度直接“起飞”

设置:

  • 100 个客户端,参与率 10%;
  • 数据分布 Dirichlet-0.6(中度非 IID);
  • 本地 epoch 数 E=5。

结果(Top-1, CIFAR-100 + ResNet-18):

  • FedSAM:47.83%
  • FedNSAM66.04% 🚀(+12.21 个百分点)

同时,达到 55% 精度所需轮数:

  • FedSAM:> 900 round
  • FedNSAM:约 316 round
    ➡️ 收敛速度快约 3 倍

在这里插入图片描述

5.2 Tiny ImageNet + 大模型(ViT / Swin)

在 Tiny ImageNet 上,作者测试了:

  • Swin-Small(50M 参数)
  • Swin-Base(80M)
  • ViT-Base(88M)

数据极端非 IID:Dirichlet-0.1,客户端参与率仅 5%。

结果:

  • Swin-Small:FedNSAM 约 70.12%;
  • Swin-Base:FedNSAM 约 70.86%;
  • ViT-Base:FedNSAM 约 71.23%。

关键是:

  • 精度领先所有对比方法;
  • 轮数更少,训练效率更高。

这说明 FedNSAM 不只是小模型好用,对 大模型 + 高度异质数据 也非常友好。


在这里插入图片描述

5.3 不同参与率 & 异质程度:鲁棒性更强

在 CIFAR-100 + ResNet-18 上,论文进一步考察:

  • 客户端参与率:2%、5%、10%;
  • Dirichlet:0.1、0.3、0.6(数值越小异质性越强)。

观察到:

  • 参与率只有 2% 时,FedNSAM 仍拿到 56.92% 的高精度;

  • 当 Dirichlet 从 0.6 降到 0.1(异质更强):

    • FedSAM 精度从 47.83% → 40.18%,大幅下滑;
    • FedNSAM 从 66.04% → 58.53%,仍远高于其他方法。

➡️ 在极端非 IID + 低参与率 场景下,FedNSAM 的优势尤为明显。


在这里插入图片描述

5.4 损失曲面可视化:FedNSAM 的“谷底更宽”

论文还可视化了各方法训练后模型周围的训练/测试损失曲面:

  • FedSAM、MoFedSAM、FedGAMMA 等在高异质场景下,曲面明显尖锐;
  • FedNSAM 的曲线则更加平滑、低凸起,谷底区域更宽。

这与“平坦最小值 → 泛化更优”的经典经验非常吻合。


在这里插入图片描述

6. 工程视角:如何把 FedNSAM 接到自己的项目里?

✅ 官方代码仓库:https://github.com/junkangLiu0/FedNSAM
建议直接 clone 下来参考实现。

这里给一个整体接入思路(PyTorch / 自建联邦框架都适用):

6.1 服务器端改动

  1. 维护全局动量向量 (\mathbf{m})

    m = torch.zeros_like(global_model_params)  # 初始化
    
  2. 每轮通信后,更新动量 & 全局参数

    # 假设 delta_list = [Δ_i^t] 为各客户端上传的模型差分
    delta_avg = average(delta_list)
    m = momentum * m + delta_avg          # momentum = λ
    global_params = global_params + m
    

6.2 客户端侧本地训练(伪代码)

以一个 step 为例(忽略 batch 等细节):

# θ_k: 当前本地参数
# m: 从服务器下发的全局动量
# λ: 动量系数
# ρ: 扰动半径
# η: 学习率

# 1. Nesterov 预判一步
theta_hat = theta_k + λ * m

# 2. SAM 扰动方向使用 -m
delta = ρ * (-m / (m.norm() + 1e-12))

# 3. 在扰动点上计算梯度
theta_perturbed = theta_hat + delta
loss = local_loss(theta_perturbed, data_batch)
g = autograd(loss, theta_perturbed)

# 4. 用该梯度更新参数(在 θ_k 上更新)
theta_k_plus_1 = theta_k - η * g

和原有 FedAvg / FedSAM 相比:

  • 不改变通信格式(仍然上传模型差分);
  • 客户端只需多接收一个全局动量向量 (m);
  • 本地更新规则略微修改即可。

6.3 超参数推荐(来自论文实验经验)

  • 学习率 (\eta):1e-3 ~ 3e-1 范围内网格搜索;
  • 动量参数 (\lambda):0.85 左右效果稳定;
  • 扰动半径 (\rho):0.1 是一个不错的默认值
  • 本地 epoch 视任务确定,非 IID 程度越高,可适当减少本地步数。

7. 个人小结:FedNSAM 的启发

这篇工作不仅给出了一个效果更强的算法,更重要的是,带来了几个有价值的思路:

  1. 从“点平坦”到“多客户端平坦对齐”

    • 不再只盯某个客户端的 sharpness,而是通过平坦距离 (\Delta_D) 量化“大家的谷底是否靠近”。
  2. 把全局信息当作“对齐器”

    • 全局 Nesterov 动量 (\mathbf{m}_t) 把各客户端的优化方向统一到一个大趋势上,
      避免“各自为战、平均之后掉缝隙”。
  3. 联邦 + 大模型 + 非 IID 是趋势,FedNSAM 很适配

    • 在 Swin / ViT 等大模型上,FedNSAM 显示出非常可观的优势;
    • 对未来大规模隐私保护训练非常有参考价值。

如果你正准备:

  • 在自己项目里实现联邦学习;
  • 或者已经在用 FedAvg / FedProx / FedSAM 想继续提升效果;

💡 不妨尝试把 优化器升级为 FedNSAM,看看在你自己的数据集上,
能带来多大的提升。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值