A Communication-Efficient Collaborative Learning Framework for Distributed Features
- 背景与动机:数据孤岛在不同组织中普遍存在,协同学习成为解决数据孤岛和隐私问题的有吸引力的方案。但现有架构在通信敏感场景中未充分解决通信问题,且存在数据泄漏和通信开销昂贵等问题。
- 方法:提出了一种名为 Federated stochastic block coordinate descent(FedBCD)的分布式特征协同学习框架,各参与方仅共享每个样本的单个值,而不是模型参数或原始数据,且能在不进行每轮迭代通信的情况下持续进行本地模型更新。
- 实验与结论:通过理论分析了本地更新次数的影响,证明了当批量大小、样本大小和本地迭代次数选择适当时,该算法在 T 次迭代内执行 O(√T)轮通信,并达到 O(1 / √T)的精度。通过在多种任务和数据集上的实验评估,证明了该方法优于随机梯度下降(SGD)方法,且添加近端项可以进一步增强在 Q 值较大时的收敛性。
其中时间复杂度的T代表算法的迭代次数。
文章的 Introduction 部分主要介绍了协同学习的背景和相关问题
研究场景:现有协同学习框架中数据多按样本分布且共享相同属性,但存在一种跨组织协同学习问题,即各方共享相同用户但具有不同的特征集,例如同一城市的本地银行和零售公司可能在用户基础上有很大重叠,构建协同学习模型将对这些方有益。
现有问题:特征分区的协同学习问题在 DL 和 FL 设置中都有研究,但现有架构未充分解决通信问题,在数据地理分布、数据局部性和隐私至关重要的场景中,这些方法通常需要每轮迭代进行通信和计算,且为防止数据泄漏采用的隐私保护技术会增加昂贵的通信开销,此外,样本分区的 FL 中通过 FedAvg 进行多次本地更新可有效减少通信轮数,但在分布式特征中进行此类本地更新的可行性尚不清楚。
本文工作:提出了名为 Federated stochastic block coordinate descent(FedBCD)的协同学习框架,各方仅在每次通信时共享每个样本的单个值,而非模型参数或原始数据,且能持续进行本地模型更新而无需每轮迭代通信,所有原始数据和模型参数都保持本地,与集中训练的模型相比性能无损失,通过采用 FedBCD 可显著降低通信成本,并通过实验评估了 FedBCD 与其他替代协议的性能,还将算法应用于联邦迁移学习(FTL)以解决标记数据少和用户重叠不足的问题。
Problem Definition
这是一个典型的联合损失函数,包含了预测误差和正则化项。损失函数取决于所有数据方的参数 Θ,并且我们对所有样本 𝐷𝑖 进行了求和。
损失函数 f 通常依赖于各数据方的特征和标签,比如线性回归、逻辑回归等。其中右边括号里表示通过不同数据方的特征和参数的加权和,得到的预测值与标签之间的损失。这个公式的核心思想是,每个数据方 𝑘 有自己的特征 𝑥𝑖𝑘,但他们共享同一个标签 𝑦。各个数据方的特征通过加权求和来生成预测值。
𝑔𝑘(Θ;𝑆):表示相对于参数 𝜃𝑘 的随机梯度(stochastic gradient),通过对小批量数据 𝑆⊆[𝑁] 进行求导得到。 梯度计算通过以下步骤进行:
- 定义 𝐻𝑖𝑘=𝑥𝑖𝑘𝜃𝑘,表示样本特征与参数的加权和。
- 公式(3)这表示对损失函数 𝑓 和正则化项 𝛾 分别对参数 𝜃𝑘 求导,得到的梯度。
- 计算局部梯度 ∇𝑓(Θ;𝑆),即公式(4),这是相对于特征的梯度。
- 将梯度转化为全局损失函数的梯度,公式(5)通过对局部梯度求和,得到整体的梯度更新方向。
- 公式(4)的意思是,我们对样本 𝑖 的损失函数进行求导,得到关于特征 𝑥𝑖𝑘 的梯度,然后对所有样本进行平均。
- 公式(5)- 全局梯度计算: 在计算了局部梯度后,各数据方将其本地的梯度发送给一个中央服务器,该服务器汇总所有数据方的信息,计算全局的梯度
为了保证各数据方可以不共享原始数据而协同计算,需要定义一种信息共享机制,
公式(6)定义了从其他数据方收集的信息: 𝐻𝑆𝑘𝑞={𝐻𝑞𝑘(𝜃𝑞,𝑆𝑞)}𝑞≠𝑘 即从其他数据方收集的梯度信息。 通过公式(7)和(8),我们将各数据方的局部信息进行汇总,得到全局梯度。最后,公式(9)给出了随机梯度下降的更新规则: 𝜃𝑘←𝜃𝑘−&