1.针对的问题
在实际场景中,客户端获取的数据往往没有任何标签
1.针对客户端有标注数据和未标注数据
2.只有服务器有标注的数据
对于数据无标签或少标签的模型训练
即半监督联邦学习的模型训练问题
2.采取的方法
传统的半监督学习是强制 增强样本和原始样本输出相同的类别标签
以提高模型的泛化能力和鲁棒性。通过增强样本,增加噪声从而不依赖某些权重参数。
以此为启发提出客户端之间的一致性正则化方法提高无标签数据的利用率
公式定义如下:
其中 是辅助代理模型的预测分布, 是本地客户端的预测分布
即当KL散度越小,代理模型与客户端模型预测结果越接近,即也就越接近真实值
注:辅助代理是服务器段端预先训练的模型,在进行联邦半监督训练,辅助代理模型并不更新参数,辅助代理的选客户端模型参数相似或者输出分布相似的模型。
一致性正则化公式:
这个公式的作用是,通过损失函数量化确保客户端模型预测结果尽量接近于伪标签,通过KL散度使客户端模型更接近代理模型预测结果,因为代理模型不更新参数,所以有助防止过拟合,即正则化。
以上就是一致性正则化函数
交叉熵损失函数提高精度
平均KL散度以一致性增强数据和原始数据预测结果和代理模型结果比较,防止过拟合。
对于传统的半监督学习 有标签和无标签使用同一套参数
但由于使用同一套参数会出现,无标签数据训练影响之前的有标签数据训练的参数,导致之前的训练参数出现偏移。
本文提出模型参数分为两个参数集合分别对有标签和无标签的数据进行训练,以防止上述影响。
对于有标签的数据仅使用损失函数进行提高模型精度与常规训练一样
对于无标签数据使用上述一致性正则化函数进行训练
对于有标签数据在客户端的情况
算法流程如下
服务器端:
首先初始化参数,随机选定A个客户端进行训练,
(6行)服务器需要找到与客户端模型最相似的参数H辅助训练(即代理模型H)
然后训练客户端模型,并存储训练后的参数 然后求平均值
客户端:
θla← σ + ψ, 代表客户端本地模型参数集合,分为训练有标签数据的参数σ,和训练无标签数据的参数ψ
θh1:H← σ + ψ1:H,代表代理模型参数集合,同样分为两类参数
然后对于不同数据分别训练,*号表示冻结参数。
对于有标签数据在服务器端的情况
即客户端的数据均没有标签的情况
服务器端:
首先初始化参数
对有标签数据进行小批梯度下降,训练全局模型,其中用于训练无标签数据的参数冻结
选择随机的客户端进行训练,并且选择相似的代理模型辅助客户端进行训练
保存参数
平均参数
客户端:
只更新无标签的参数并上传到服务器端
实验的设置
有三种实验的任务类型
-
Batch-IID(批量 IID)
使用CIFAR-10,数据划分为:
训练集 54000张
验证集3000张
测试集3000张
每个客户端随机抽取5个数据作为有标签数据集合,其他作为无标签数据 -
Batch-NonIID
与Batch-IID相同,人为构造客户端数据类别不平衡 -
Streaming-NonIID 即随着时间逐步到达的异构数据
使用Fashion-MNIST(共70000张图像)
数据划分为:
训练集 63000张。
验证集 3500张。
测试集 3500张。
每个客户端随机抽取5张图片作为有标签数据集合,其他作为无标签数据集合
并且仍然是类别不均衡的客户端数据集合。
数据是流式进入客户端,对于客户端的无标签数据集合会进一步划分多个时间步
每个时间步都会训练10轮
实验的结果
通信开销变小的同时准确率并没有低多少