Consistency of Local and Global Flatness for Federated Learning
【联邦学习也玩“找平坦最小值”?一文读懂 FedNSAM】 Junkang Liu(刘俊康)

论文摘要:
在联邦学习(Federated Learning, FL)中,多步本地更新和数据异质性通常会导致更“尖锐”的全局极小值,从而削弱全局模型的性能。为了解决这一问题,许多主流 FL 算法会在本地训练中引入 sharpness-aware minimization(SAM)。然而,在高数据异质性的场景下,本地训练中得到的“平坦性”并不一定意味着全局模型同样是平坦的。因此,仅仅在各客户端数据上最小化本地损失曲面的锐度,并不能保证 SAM 在 FL 中真正提升全局模型的泛化能力。
为了解释这一现象,我们首先引入了一个名为 平坦距离(flatness distance) 的度量,并对其进行了重新思考与理论分析。在此基础上,我们提出了一种全新的 FedNSAM 算法:通过在本地更新中引入全局 Nesterov 动量(global Nesterov momentum),来加速 SAM,并协调全局与本地平坦性的“一致性”。FedNSAM 使用全局 Nesterov 动量作为客户端对全局扰动进行本地估计以及外推(extrapolation)的方向。
从理论上,我们基于 Nesterov 外推证明了 FedNSAM 相比 FedSAM 具有更紧的收敛上界;从实验上,我们在 CNN 和 Transformer 模型上进行了大量实验,验证了 FedNSAM 在性能和效率上的显著优势。代码已开源,地址为:
https://github.com/junkangLiu0/FedNSAM
一作 Junkang Liu(刘俊康)的公开资料/简介链接。
🔗 Google 学术个人主页
https://scholar.google.com/citations?user=N7pJWIoAAAAJ&hl=zh-CN ([谷歌学术][1])
ℹ️ 关于 Junkang Liu 的简要背景
- 他目前是 Tianjin University(天津大学)的 PhD 学生,同时也曾在 Xidian University(西安电子科技大学)攻读 MS。
- 他的研究方向包括联邦学习(Federated Learning)、优化算法、分布式训练等。([OpenReview][2])
- 除了最近的 FedSWA / FedMoSWA(Improving Generalization in Federated Learning …)论文之外,他也参与其他与 Federated Learning 优化相关的研究,例如 FedBCGD ,FedSWA, FedAdamW,FedNSAM等。
文章目录
- Consistency of Local and Global Flatness for Federated Learning
- 【联邦学习也玩“找平坦最小值”?一文读懂 FedNSAM】 Junkang Liu(刘俊康)
- 🟩 图 2 的核心想表达一句话:
- 🟦 左图 (a):Flatness Distance(平坦距离)
- 🟧 右图 (b):Global Sharpness(全局尖锐度)
- 🟨 文字说明总结含义
🎯 一句话总结:FedNSAM = 联邦学习 + SAM + 全局 Nesterov 动量
让各个客户端不再“各找各的平坦谷底”,而是一起对齐到更“宽、更稳”的全局平坦解。
0. 论文 & 代码地址
- 📄 论文标题:FedNSAM: Sharpness-Aware Minimization in Federated Learning with Distance of Flat Minimum
- 🧪 官方代码仓库:FedNSAM (PyTorch 实现)
👉 GitHub:https://github.com/junkangLiu0/FedNSAM
1. 背景:联邦学习为啥这么“难训稳”?
联邦学习(Federated Learning, FL)已经在医疗、金融、移动端等场景广泛应用:
- 数据保留在本地,不上传原始数据,提升隐私保护;
- 服务器只负责聚合来自各个客户端的模型更新。
然而,真正落地时会遇到几个经典痛点:
-
🔸 数据高度非 IID:每个客户端数据分布不一样,比如:
- 医院 A 以老年人为主;
- 医院 B 以年轻人为主;
-
🔸 本地多步更新(local epochs 多):每个客户端都会按自己数据“疯狂更新”;
-
🔸 全局模型容易掉进“尖锐谷底”:
- 每个客户端都朝着“自己局部的低谷”走;
- 聚合后的全局模型,可能刚好落到“谷与谷之间的缝隙”。
结果就是:训练集 loss 也许不难看,但测试集效果很拉胯,泛化性能差。
2. SAM 回顾:别只看一个点,要看“附近一片区域”
2.1 传统训练:只在意当前点
传统优化目标:
min
θ
F
(
θ
)
\min_{\theta} F(\theta)
θminF(θ)
只关心在某个参数点 (\theta) 上的损失值,不关心它附近区域是不是也好。
这会导致:
- 在训练集上“刚刚好”很低;
- 但稍微扰动一下参数,loss 就爆表;
- 这类解被称为 sharp minima(尖锐最小值),泛化表现往往不好。
2.2 SAM 的核心思想:最坏扰动下也要好
SAM(Sharpness-Aware Minimization)的目标是:
让当前点周围一小片“球形邻域”内的最坏点,也尽量损失不大。
数学形式:
min
θ
max
∣
δ
∣
2
≤
ρ
F
(
θ
+
δ
)
\min_\theta \max_{|\delta|_2 \le \rho} F(\theta + \delta)
θmin∣δ∣2≤ρmaxF(θ+δ)
操作流程(简化版):
- 在当前参数 (\theta) 附近加一个扰动 δ \delta δ, ∣ δ ∣ ≤ ρ |\delta| \le \rho ∣δ∣≤ρ;
- 找到这块邻域里损失最大的点;
- 让这个“最坏点”的损失也尽可能小。
直观理解:
- 如果一个解附近都“比较平”,就算被扰动一下,损失也不会太差;
- 如果是尖锐谷底,一点点动就损失暴涨,SAM 会主动避开这类解。
2.3 FedSAM:把 SAM 搬进联邦学习
FedSAM 的做法:
- 每个客户端本地训练时,用 SAM 做优化;
- 服务器只做普通参数聚合。
在数据异质性不严重的场景下,FedSAM 能显著提升全局模型的泛化性能。
但论文发现:
当数据高度非 IID 时,各客户端找到的“本地平坦最小值”彼此差距很大,
聚合后的全局模型依然不平坦,sharpness 很高,泛化仍然糟糕。
所以作者提出一个新问题:
👉 如何衡量“本地平坦最小值”与“全局模型”的距离?
3. 平坦距离(Flatness Distance):把“几何形状”量化出来
3.1 定义:客户端最小值 vs 全局模型
设第 (t) 轮训练后:
- 客户端 (i) 做完 (K) 次本地更新后的模型为:(\theta_{i,K}^t);
- 服务器聚合后的全局模型为:
θ t + 1 = 1 N ∑ i = 1 N θ i , K t \theta^{t+1} = \frac{1}{N}\sum_{i=1}^N \theta_{i,K}^t θt+1=N1i=1∑Nθi,Kt
论文定义 平坦距离(Flatness Distance)为:
Δ D = 1 N ∑ i = 1 N E ∣ θ i , K t − θ t + 1 ∣ 2 \Delta_D = \frac{1}{N} \sum_{i=1}^N \mathbb{E} |\theta_{i,K}^t - \theta^{t+1}|^2 ΔD=N1i=1∑NE∣θi,Kt−θt+1∣2
直觉理解:
- 每个客户端都有一个“自认为平坦的谷底”;
- 全局模型是这些谷底的“平均值”;
- 如果这些谷底彼此靠得很近,全局模型也会在一个大致平坦的区域;
- 如果它们离得很远,平均后全局模型可能掉在**“谷与谷之间的缝隙里”**,变得很尖锐。

-
上面三张图:
- (a) FedSAM, α=0.6:数据异质性不高时,损失曲面比较平坦,FedSAM 还能找到不错的“宽谷底”。
- (b) FedSAM, α=0.1:数据高度非 IID 时,各客户端学到的东西差太多,聚合后全局模型落在“谷底之间的缝隙”,整体变“尖”,泛化变差。
- © FedNSAM, α=0.1:即使在高度非 IID 下,通过全局 Nesterov 动量对齐方向,FedNSAM 让全局模型重新回到更宽、更稳的谷底。
-
下面三张示意图:
- (d) 低异质:各客户端的平坦区域重叠多,平均后还在平坦区。
- (e) 高异质:平坦区分得很开,平均后掉到平坦区之外。
- (f) FedNSAM 通过修正(全局动量),让客户端的更新方向更一致,最终的全局模型回到“共同的平坦低损失区域”里。
3.2 关键发现:平坦距离越大,全局越“尖”

下面是对 论文图 2(Figure 2) 的简洁清晰解释:
🟩 图 2 的核心想表达一句话:
FedNSAM 能显著减少平坦距离、降低全局 sharpness,因此在高异质数据下比 FedSAM 泛化更好。
🟦 左图 (a):Flatness Distance(平坦距离)
-
曲线越低 → 客户端本地最优点之间越接近 → 全局越容易落在共同的“平坦区”。
-
现象:
- FedSAM(绿、红、蓝三条)随着异质性增大(Dirichlet 0.1 → 0.3 → 0.6)平坦距离更大,下降更慢。
- FedNSAM(黄色)始终最低,下降最快 → 客户端的本地平坦区域被成功“对齐”,一致性更高。
➡️ FedNSAM 明显减少客户端之间的偏移,让平均后的全局模型更稳定。
🟧 右图 (b):Global Sharpness(全局尖锐度)
-
曲线越低 → 全局模型越平坦 → 泛化能力越好。
-
现象:
- FedSAM 在高度异质(Dirichlet 0.1)下 sharpness 最高,曲线下降慢。
- FedNSAM(黄色)全程 sharpness 最低,表示找到的全局解更平坦。
➡️ FedNSAM 有效缓解了由于高度异质造成的“尖锐全局最小值”。
🟨 文字说明总结含义
-
论文使用 CIFAR-100、100 个客户端、10% 参与率。
-
异质性越强(Dirichlet 越小),FedSAM 性能下降越明显:
- 40.18%(Dir=0.1)
- 46.02%(Dir=0.3)
- 47.83%(Dir=0.6)
-
FedNSAM 在 Dir=0.1 时达到 58.53% → 远高于 FedSAM 的 40.18%。
➡️ FedNSAM 在高度非 IID 场景下提升巨大(+18.35 个点)。
图 2 显示:
- 左图:FedNSAM 能让客户端本地模型更一致(平坦距离更低)。
- 右图:FedNSAM 让全局模型更加平坦(更低 sharpness)。
- 这两点共同解释了:FedNSAM 在非 IID 数据下远优于 FedSAM。
如果你愿意,我可以帮你把这段内容整理成一段更适合 CSDN 的可直接粘贴排版版。
论文通过理论分析和实验发现:
- 数据越非 IID, Δ D \Delta_D ΔD 越大;
- Δ D \Delta_D ΔD 越大,全局 sharpness 越高,测试精度越差。
于是问题清晰了:
我们不只要在本地“找平坦”,
还要让各客户端的平坦区域尽量对齐,减小平坦距离 Δ D \Delta_D ΔD。
4. FedNSAM:用全局 Nesterov 动量“对齐平坦区域”


4.1 FedNSAM 的主意一句话
在本地做 SAM 时,引入全局 Nesterov 动量,
把各客户端“找平坦”的方向,对齐成尽量一致的全局趋势。
这样:
- 本地仍然是“平坦意识”的优化(SAM 思路在);
- 但更新方向受全局历史信息引导;
- 聚合后,全局更容易处在大家“共同比较平坦”的区域。

4.2 回顾一下 Nesterov 动量(NAG)
经典 Nesterov 加速梯度(NAG)大致为:
θ
t
+
1
/
2
=
θ
t
−
λ
v
t
v
t
+
1
=
λ
v
t
+
η
∇
F
(
θ
t
+
1
/
2
)
θ
t
+
1
=
θ
t
+
v
t
+
1
\begin{aligned} \theta_{t+1/2} &= \theta_t - \lambda v_t \ v_{t+1} &= \lambda v_t + \eta \nabla F(\theta_{t+1/2}) \ \theta_{t+1} &= \theta_t + v_{t+1} \end{aligned}
θt+1/2=θt−λvt vt+1=λvt+η∇F(θt+1/2) θt+1=θt+vt+1
特点:
- 不是在当前点上算梯度,而是在**“预判一步之后”的点**上算;
- 收敛往往比普通动量更快。
4.3 服务器端:维护全局动量
在 FedNSAM 中,服务器维护一个全局动量向量 (\mathbf{m}_t):
m t = λ m t − 1 + Δ t \mathbf{m}_t = \lambda \mathbf{m}_{t-1} + \Delta_t mt=λmt−1+Δt
其中:
- λ \lambda λ:动量系数(论文实验中约 0.85 效果最好);
- Δ t = 1 S ∑ i ∈ S t Δ i t \Delta_t = \frac{1}{S}\sum_{i \in S_t} \Delta_i^t Δt=S1∑i∈StΔit 是本轮参与客户端的平均更新;
- Δ i t = θ i , K t − θ i , 0 t \Delta_i^t = \theta_{i,K}^t - \theta_{i,0}^t Δit=θi,Kt−θi,0t。
然后服务器更新:
θ t = θ t − 1 + m t \theta_t = \theta_{t-1} + \mathbf{m}_t θt=θt−1+mt
➡️ 全局模型不再只是“简单平均”,而是叠加了历史趋势的加速。
4.4 客户端本地更新:Nesterov + SAM 融合版
在第 (t) 轮,本地第 (k) 步,客户端 (i) 的更新规则简要如下:
- Nesterov 预判一步:
θ i , k + 1 / 4 t = θ i , k t + λ m t \theta_{i,k+1/4}^t = \theta_{i,k}^t + \lambda \mathbf{m}_t θi,k+1/4t=θi,kt+λmt
-
SAM 扰动方向:用全局动量,而不是本地梯度:
δ i , k t = ρ ⋅ − m t ∣ m t ∣ \delta_{i,k}^t = \rho \cdot \frac{-\mathbf{m}_t}{|\mathbf{m}_t|} δi,kt=ρ⋅∣mt∣−mt
-
得到扰动后的参数:
θ i , k + 1 / 2 t = θ i , k + 1 / 4 t + δ i , k t \theta_{i,k+1/2}^t = \theta_{i,k+1/4}^t + \delta_{i,k}^t θi,k+1/2t=θi,k+1/4t+δi,kt
- 在扰动点上计算梯度并更新:
θ i , k + 1 t = θ i , k t − η ∇ F i ( θ i , k + 1 / 2 t ; ζ i ) \theta_{i,k+1}^t = \theta_{i,k}^t - \eta \nabla F_i(\theta_{i,k+1/2}^t; \zeta_i) θi,k+1t=θi,kt−η∇Fi(θi,k+1/2t;ζi)
可以看到:
- 本地 SAM 还在做“最坏扰动下的优化”;
- 但扰动方向和预移动方向都依赖全局动量 m t \mathbf{m}_t mt;
- 各客户端的“平坦区域搜索路径”因此被对齐。
4.5 优势总结

FedNSAM 相比 FedSAM / FedAvg 等方法,主要有三点优势:
-
✅ 更好的理论收敛上界
- 在适当学习率和扰动半径下,FedNSAM 的收敛速度
O ( L F T K S ( 1 − λ ) ) \mathcal{O}\left(\frac{\sqrt{LF}}{\sqrt{TKS(1-\lambda)}}\right) O(TKS(1−λ)LF)
优于 FedSAM。
- 在适当学习率和扰动半径下,FedNSAM 的收敛速度
-
✅ 更小的平坦距离 (\Delta_D)
- 理论上证明在同等条件下,FedNSAM 对数据异质性的敏感度更低;
- 实验上也确实观测到: Δ D \Delta_D ΔD 更小,全局 sharpness 更低。
-
✅ 通信成本几乎不变
- 每个客户端只需上传本地模型差分 Δ i t \Delta_i^t Δit;
- 不像某些方法需要额外上传控制变量 / 二阶信息;
- 非常适合实际工程系统落地。
5. 实验结果:从 CNN 到 Transformer 全面提升

论文在多个数据集与网络上做了系统对比,包括:
- CIFAR-10 / CIFAR-100 / Tiny-ImageNet;
- LeNet-5、VGG-11、ResNet-18;
- Swin-Small / Swin-Base / ViT-Base 等。
5.1 CIFAR-100 + ResNet-18:精度直接“起飞”
设置:
- 100 个客户端,参与率 10%;
- 数据分布 Dirichlet-0.6(中度非 IID);
- 本地 epoch 数 E=5。
结果(Top-1, CIFAR-100 + ResNet-18):
- FedSAM:47.83%
- FedNSAM:66.04% 🚀(+12.21 个百分点)
同时,达到 55% 精度所需轮数:
- FedSAM:> 900 round
- FedNSAM:约 316 round
➡️ 收敛速度快约 3 倍。

5.2 Tiny ImageNet + 大模型(ViT / Swin)
在 Tiny ImageNet 上,作者测试了:
- Swin-Small(50M 参数)
- Swin-Base(80M)
- ViT-Base(88M)
数据极端非 IID:Dirichlet-0.1,客户端参与率仅 5%。
结果:
- Swin-Small:FedNSAM 约 70.12%;
- Swin-Base:FedNSAM 约 70.86%;
- ViT-Base:FedNSAM 约 71.23%。
关键是:
- 精度领先所有对比方法;
- 轮数更少,训练效率更高。
这说明 FedNSAM 不只是小模型好用,对 大模型 + 高度异质数据 也非常友好。

5.3 不同参与率 & 异质程度:鲁棒性更强
在 CIFAR-100 + ResNet-18 上,论文进一步考察:
- 客户端参与率:2%、5%、10%;
- Dirichlet:0.1、0.3、0.6(数值越小异质性越强)。
观察到:
-
参与率只有 2% 时,FedNSAM 仍拿到 56.92% 的高精度;
-
当 Dirichlet 从 0.6 降到 0.1(异质更强):
- FedSAM 精度从 47.83% → 40.18%,大幅下滑;
- FedNSAM 从 66.04% → 58.53%,仍远高于其他方法。
➡️ 在极端非 IID + 低参与率 场景下,FedNSAM 的优势尤为明显。

5.4 损失曲面可视化:FedNSAM 的“谷底更宽”
论文还可视化了各方法训练后模型周围的训练/测试损失曲面:
- FedSAM、MoFedSAM、FedGAMMA 等在高异质场景下,曲面明显尖锐;
- FedNSAM 的曲线则更加平滑、低凸起,谷底区域更宽。
这与“平坦最小值 → 泛化更优”的经典经验非常吻合。

6. 工程视角:如何把 FedNSAM 接到自己的项目里?
✅ 官方代码仓库:https://github.com/junkangLiu0/FedNSAM
建议直接 clone 下来参考实现。
这里给一个整体接入思路(PyTorch / 自建联邦框架都适用):
6.1 服务器端改动
-
维护全局动量向量 (\mathbf{m}):
m = torch.zeros_like(global_model_params) # 初始化 -
每轮通信后,更新动量 & 全局参数:
# 假设 delta_list = [Δ_i^t] 为各客户端上传的模型差分 delta_avg = average(delta_list) m = momentum * m + delta_avg # momentum = λ global_params = global_params + m
6.2 客户端侧本地训练(伪代码)
以一个 step 为例(忽略 batch 等细节):
# θ_k: 当前本地参数
# m: 从服务器下发的全局动量
# λ: 动量系数
# ρ: 扰动半径
# η: 学习率
# 1. Nesterov 预判一步
theta_hat = theta_k + λ * m
# 2. SAM 扰动方向使用 -m
delta = ρ * (-m / (m.norm() + 1e-12))
# 3. 在扰动点上计算梯度
theta_perturbed = theta_hat + delta
loss = local_loss(theta_perturbed, data_batch)
g = autograd(loss, theta_perturbed)
# 4. 用该梯度更新参数(在 θ_k 上更新)
theta_k_plus_1 = theta_k - η * g
和原有 FedAvg / FedSAM 相比:
- 不改变通信格式(仍然上传模型差分);
- 客户端只需多接收一个全局动量向量 (m);
- 本地更新规则略微修改即可。
6.3 超参数推荐(来自论文实验经验)
- 学习率 (\eta):1e-3 ~ 3e-1 范围内网格搜索;
- 动量参数 (\lambda):0.85 左右效果稳定;
- 扰动半径 (\rho):0.1 是一个不错的默认值;
- 本地 epoch 视任务确定,非 IID 程度越高,可适当减少本地步数。
7. 个人小结:FedNSAM 的启发
这篇工作不仅给出了一个效果更强的算法,更重要的是,带来了几个有价值的思路:
-
从“点平坦”到“多客户端平坦对齐”
- 不再只盯某个客户端的 sharpness,而是通过平坦距离 (\Delta_D) 量化“大家的谷底是否靠近”。
-
把全局信息当作“对齐器”
- 全局 Nesterov 动量 (\mathbf{m}_t) 把各客户端的优化方向统一到一个大趋势上,
避免“各自为战、平均之后掉缝隙”。
- 全局 Nesterov 动量 (\mathbf{m}_t) 把各客户端的优化方向统一到一个大趋势上,
-
联邦 + 大模型 + 非 IID 是趋势,FedNSAM 很适配
- 在 Swin / ViT 等大模型上,FedNSAM 显示出非常可观的优势;
- 对未来大规模隐私保护训练非常有参考价值。
如果你正准备:
- 在自己项目里实现联邦学习;
- 或者已经在用 FedAvg / FedProx / FedSAM 想继续提升效果;
💡 不妨尝试把 优化器升级为 FedNSAM,看看在你自己的数据集上,
能带来多大的提升。


&spm=1001.2101.3001.5002&articleId=155324526&d=1&t=3&u=71ba2d3a197d40279f76e734ba7752ff)
1919

被折叠的 条评论
为什么被折叠?



