MLSys 2020 | FedProx:异质网络的联邦优化

本文探讨了FedProx算法如何在FedAvg基础上解决设备计算能力差异和数据分布不均问题。FedProx引入近端项限制模型偏移,并允许不同设备训练不精确解,从而加速在异质环境下联邦优化的收敛。实验显示,适当调整μ参数能有效改善模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

在这里插入图片描述
题目: Federated Optimization for Heterogeneous Networks
会议: Conference on Machine Learning and Systems 2020
论文地址:Federated Optimization for Heterogeneous Networks

FedAvg对设备异质性和数据异质性没有太好的解决办法,FedProx在FedAvg的基础上做出了一些改进来尝试缓解这两个问题。

在Online Learning中,为了防止模型根据新到来的数据进行更新后偏离原来的模型太远,也就是为了防止过调节,通常会加入一个余项来限制更新前后模型参数的差异。FedProx中同样引入了一个余项,作用类似。

I. FedAvg

Google的团队首次提出了联邦学习,并引入了联邦学习的基本算法FedAvg。问题的一般形式:
在这里插入图片描述
公式1: f i ( w ) = l ( x i , y i ; w ) f_i(w)=l(x_i,y_i;w) fi(w)=l(xi,yi;w)表示第 i i i个样本的损失,即最小化所有样本的平均损失。

公式2: F k ( w ) F_k(w) Fk(w)表示一个客户端内所有数据的平均损失, f ( w ) f(w) f(w)表示当前参数下所有客户端的加权平均损失。

值得注意的是,如果所有 P k P_k Pk(第k个客户端的数据)都是通过随机均匀地将训练样本分布在客户端上来形成的,那么每一个 F k ( w ) F_k(w) Fk(w)的期望都为 f ( w ) f(w) f(w)。这通常是由分布式优化算法做出的IID假设:即每一个客户端的数据相互之间都是独立同分布的。

FedAvg:
在这里插入图片描述
简单来说,在FedAvg的框架下:每一轮通信中,服务器分发全局参数到各个客户端,各个客户端利用本地数据训练相同的epoch,然后再将梯度上传到服务器进行聚合形成更新后的参数。

FedAvg存在着两个缺陷:

  1. 设备异质性:不同的设备间的通信和计算能力是有差异的。在FedAvg中,被选中的客户端在本地都训练相同的epoch,虽然作者指出提升epoch可以有效减小通信成本,但较大的epoch下,可能会有很多设备无法按时完成训练。无论是直接drop掉这部分客户端的模型还是直接利用这部分未完成的模型来进行聚合,都将对最终模型的收敛造成不好的影响。
  2. 数据异质性:不同设备中数据可能是非独立同分布的。如果数据是独立同分布的,那么本地模型训练较多的epoch会加快全局模型的收敛;如果不是独立同分布的,不同设备在利用非IID的本地数据进行训练并且训练轮数较大时,本地模型将会偏离初始的全局模型。

II. FedProx

为了缓解上述两个问题,本文作者提出了一个新的联邦学习框架FedProx。FedProx能够很好地处理异质性。

定义一:
在这里插入图片描述
所谓 γ \gamma γ inexact solution:对于一个待优化的目标函数 h ( w ; w 0 ) h(w;w_0) h(w;w0),如果有:
∣ ∣ ∇ h ( w ∗ ; w 0 ) ∣ ∣ ≤ γ ∣ ∣ ∇ h ( w 0 ; w 0 ) ∣ ∣ ||\nabla h(w^*;w_0)|| \leq \gamma ||\nabla h(w_0;w_0)|| ∣∣∇h(w;w0)∣∣γ∣∣∇h(w0;w0)∣∣
这里 γ ∈ [ 0 , 1 ] \gamma \in [0,1] γ[0,1],我们就说 w ∗ w^* w h h h的一个 γ − \gamma- γ不精确解。

对于这个定义,我们可以理解为:梯度越小越精确,因为梯度越大,就需要更多的时间去收敛。那么很显然, γ \gamma γ越小,解 w ∗ w^* w越精确

我们知道,在FedAvg中,设备 k k k在本地训练时,需要最小化的目标函数为:
F k ( w ) = 1 n k ∑ i ∈ P k f i ( w ) F_k(w)=\frac{1}{n_k}\sum_{i \in P_k}f_i(w) Fk(w)=nk1iPkfi(w)
简单来说,每个客户端都是优化所有样本的损失和,这个是正常的思路,让全局模型在本地数据集上表现更好。

但如果设备间的数据是异质的,每个客户端优化之后得到的模型就与初始时服务器分配的全局模型相差过大,本地模型将会偏离初始的全局模型,这将减缓全局模型的收敛。

为了有效限制这种偏差,本文作者提出,设备 k k k在本地进行训练时,需要最小化以下目标函数:
h k ( w ; w t ) = F k ( w ) + μ 2 ∣ ∣ w − w t ∣ ∣ 2 h_k(w;w^t)=F_k(w)+\frac{\mu}{2}||w-w^t||^2 hk(w;wt)=Fk(w)+2μ∣∣wwt2
作者在FedAvg损失函数的基础上,引入了一个proximal term,我们可以称之为近端项。引入近端项后,客户端在本地训练后得到的模型参数 w w w将不会与初始时的服务器参数 w t w^t wt偏离太多。

观察上式可以发现,当 μ = 0 \mu=0 μ=0时,FedProx客户端的优化目标就与FedAvg一致。

这个思路其实还是很常见的,在机器学习中,为了防止过调节,亦或者为了限制参数变化,通常都会在原有损失函数的基础上加上这样一个类似的项。比如在在线学习中,我们就可以添加此项来限制更新前后模型参数的差异。

FedProx的算法伪代码:
在这里插入图片描述
输入:客户端总数 K K K、通信轮数 T T T μ \mu μ γ \gamma γ、服务器初始化参数 w 0 w^0 w0,被选中的客户端的个数 N N N,第 k k k个客户端被选中的概率 p k p_k pk

对每一轮通信:

  1. 服务器首先根据概率 p k p_k pk随机选出一批客户端,它们的集合为 S t S_t St
  2. 服务器将当前参数 w t w^t wt发送给被选中的客户端。
  3. 每一个被选中的客户端需要寻找一个 w k t + 1 w_k^{t+1} wkt+1,这里的 w k t + 1 w_k^{t+1} wkt+1不再是FedAvg中根据本地数据SGD优化得到的,而是优化 h k ( w ; w t ) h_k(w;w^t) hk(w;wt)后得到的 γ − \gamma- γ不精确解。
  4. 每个客户端将得到的不精确解传递回服务器,服务器聚合这些参数得到下一轮初始参数。

通过观察这个步骤可以发现,FedProx在FedAvg上做了两点改进:

  1. 引入了近端项,限制了因为数据异质性导致的模型偏离。
  2. 引入了不精确解,各个客户端不再需要训练相同的轮数,只需要得到一个不精确解,这有效缓解了某些设备的计算压力。

III. 实验

图1给出了数据异质性对模型收敛的影响:
在这里插入图片描述
上图给出了损失随着通信轮数增加的变化情况,数据的异质性从左到右依次增加,其中 μ = 0 \mu=0 μ=0表示FedAvg。可以发现,数据间异质性越强,收敛越慢,但如果我们让 μ > 0 \mu>0 μ>0,将有效缓解这一情况,也就是模型将更快收敛

图2:
在这里插入图片描述
左图:E增加后对 μ = 0 \mu=0 μ=0情况的影响。可以发现,太多的本地训练将导致本地模型偏离全局模型,全局模型收敛变缓

中图:同一数据集,增加 μ \mu μ后,收敛将加快,因为这有效缓解了模型的偏移,从而使FedProx的性能较少依赖于 E E E

作者给出了一个trick:在实践中, μ \mu μ可以根据模型当前的性能自适应地选择。比较简单的做法是当损失增加时增加 μ \mu μ,当损失减少时减少 μ \mu μ

但是对于 γ \gamma γ,作者貌似没有具体说明怎么选择,只能去GitHub上研究一下源码再给出解释了。

IV. 总结

数据和设备的异质性对传统的FedAvg算法提出了挑战,本文作者在FedAvg的基础上提出了FedProx,FedProx相比于FedAvg主要有以下两点不同:

  1. 考虑了不同设备通信和计算能力的差异,并引入了不精确解,不同设备不需要训练相同的轮数,只需要得到一个不精确解即可。
  2. 引入了近端项,在数据异质的情况下,限制了本地训练时模型对全局模型的偏离。
### 联邦学习中的数据异构性研究 联邦学习是一种分布式机器学习技术,能够在保护用户隐私的同时联合多个节点进行模型训练[^3]。然而,由于参与训练的节点可能来自不同的地理位置、组织或应用场景,其数据来源和特征分布往往存在显著差异,这种现象被称为数据异构性(Statistical Heterogeneity)。数据异构性可能导致全局模型性能下降甚至发生偏差。 #### 数据异构性的定义与影响 数据异构性是指在联邦学习环境中,各客户端的数据分布不符合独立同分布假设(i.i.d. assumption),即不同客户端之间的数据可能存在类别不平衡、样本重叠度低等问题[^1]。这些特性会对联邦学习的效果造成负面影响,具体表现为: - **模型收敛速度变慢**:传统联邦平均算法FedAvg)在面对非独立同分布(Non-IID)数据时,可能会因为局部最优解的存在而导致整体收敛效率降低[^4]。 - **模型泛化能力减弱**:如果某些类别的数据仅存在于部分客户端,则全局模型对该类别的预测准确性可能大幅下降[^2]。 #### 解决数据异构性的方法概述 为了缓解数据异构性带来的问题,研究人员提出了多种改进策略,主要包括以下几个方面: 1. **优化本地更新过程** 部分工作通过调整本地目标函数的形式来增强对 Non-IID 数据的适应能力。例如,FedProx 方法引入了一个额外的正则化项,用于约束本地模型参数偏离全局模型的程度,从而减少因数据分布差异引起的震荡效应[^4]。 2. **设计新型聚合机制** 另一种思路是从全局视角出发重新定义模型融合方式。Virtual Homogeneity Learning 是一项最新研究成果,它尝试构建虚拟均质空间使得原本高度异质化的个体贡献变得相对一致,最终达到改善测试表现的目的[^5]。 3. **采用元学习框架** 利用元学习的思想可以快速适配新任务并捕捉跨域规律,这对于应对复杂多样的实际场景尤为适用。尽管目前该方向尚处于探索阶段,但仍展现出巨大潜力值得进一步挖掘。 以下是几篇针对联邦学习中数据异构性展开深入探讨的经典文献推荐供参考: ```plaintext [1] McMahan H B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[J]. arXiv preprint arXiv:1602.05629, 2016. [2] Li T, Sahu A K, Zaheer M, et al. Federated optimization in heterogeneous networks[M]//Proceedings of MLSys'20 Workshop on Distributed and Private Machine Learning. 2020. [3] Smith V, Chiang M, Sanjabi M, et al. Federated multi-task learning[C]//Advances in Neural Information Processing Systems. 2017. [4] Reddi S J, Charles Z, Zaheer M, et al. On the convergence of fedavg on non-iid data[C]//International Conference on Learning Representations. 2020. [5] Tang Z, Zhang Y, Shi S, et al. Virtual homogeneity learning: Defending against data heterogeneity in federated learning[C]//International Conference on Machine Learning. PMLR, 2022: 21111-21132. ``` ### 示例代码片段展示如何模拟 Non-IID 数据环境 下面提供了一段 Python 实现代码用来生成简单的分类任务下具有不同程度标签分离特性的合成数据集。 ```python import numpy as np from sklearn.datasets import make_classification def generate_non_iid_data(num_clients=10, num_classes=5, alpha=0.5): """ Generate synthetic classification dataset with varying levels of label skew. Parameters: num_clients (int): Number of clients to distribute data among. num_classes (int): Total number of classes present within entire population. alpha (float): Dirichlet distribution concentration parameter controlling degree of imbalance between partitions. Returns: dict: Dictionary mapping each client ID to their respective training samples & labels. """ X, y = make_classification(n_samples=500 * num_clients, n_features=20, n_informative=num_classes, n_redundant=0, n_clusters_per_class=1, weights=None, flip_y=0., class_sep=1.) # Split global indices into per-client shards via dirichlet sampling over one-hot encoded targets proportions = [] for i in range(num_clients): proportion = np.random.dirichlet(np.repeat(alpha, num_classes)) proportions.append(proportion) cumulative_proportions = np.cumsum([p / sum(p) for p in proportions], axis=-1).tolist() partitioned_indices = [[] for _ in range(num_clients)] sorted_idx = np.argsort(y) current_client_id = 0 start_index = 0 while True: end_index = next((idx for idx,val in enumerate(cumulative_proportions[current_client_id]) if val >= ((sorted_idx-start_index)/len(sorted_idx))[start_index]), None) if not end_index: break selected_sample_ids = list(range(start_index,end_index)) partitioned_indices[current_client_id].extend(selected_sample_ids) current_client_id +=1 if current_client_id>=num_clients: break start_index=end_index result_dict={f"Client_{cid}": {"data":X[idx],"labels":y[idx]} for cid,idx in zip(range(len(partitioned_indices)),partitioned_indices)} return result_dict if __name__ == "__main__": generated_datasets = generate_non_iid_data() print({k:v["labels"].shape for k,v in generated_datasets.items()}) ```
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值