联邦学习框架
本文主要期望介绍一个设计良好的联邦学习框架 Flower,在开始介绍 Flower 框架的细节前,先了解下联邦学习框架的基础知识。
作为一个联邦学习框架,必然会包含对横向联邦学习的支持。横向联邦是指拥有类似数据的多方可以在不泄露数据的情况下联合训练出一个模型,这个模型可以充分利用各方的数据,接近将全部数据集中在一起进行训练的效果。横向联邦学习的一般流程如下:
横向联邦学习的过程简单理解如下:
- 各个参与方基于自身的数据训练出本地模型,将模型参数发送给公共的服务端;
- 服务端将收到的多个模型参数聚合生成全局的模型参数,然后下发给各个参与方;
- 参与方使用全局的模型参数更新本地模型,重复这个步骤直到模型训练达到要求;
从上面的过程可以看到,作为一个联邦学习框架,需要关注下面要点:
- 参与方本地模型训练;
- 模型参数的传输;
- 模型的聚合策略;
Flower 框架上手
Flower 是一个轻量的联邦学习框架,提出于 2020 年。一直以来,因为设计良好,方便扩展受到了比较多的关注。团队通过论文 FLOWER: A FRIENDLY FEDERATED LEARNING FRAMEWORK 介绍了框架的设计思想。通过论文可以看到框架设计主要追求下面目标:
- 可拓展,支持大量的客户端同时进行模型训练;
- 使用灵活,支持异构的客户端,通信协议,隐私策略,支持新功能的开销小;
首先基于 Flower 框架实际进行了一个机器学习的模型训练,通过实际动手可以感受下基于 Flower 框架可以用相当简单的方式实现一个联邦学习模型训练流程。这个流程是参考 Flower Quickstart 实现的。
常规机器学习部分实现
首先实现机器学习训练所需的基础方法,主要是数据集的准备,定义所需的模型,封装训练与测试流程,这部分与联邦学习无关。熟悉 pytorch 的应该很容易理解这部分代码:
from collections import OrderedDict
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 定义神经网络模型
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
# 定义模型训练流程
def train(net, trainloader, epochs):
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in tqdm(trainloader):
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()
# 定义模型推理流程
def test(net, testloader):
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in tqdm(testloader)