深入探索联邦学习框架 Flower

联邦学习框架

本文主要期望介绍一个设计良好的联邦学习框架 Flower,在开始介绍 Flower 框架的细节前,先了解下联邦学习框架的基础知识。

作为一个联邦学习框架,必然会包含对横向联邦学习的支持。横向联邦是指拥有类似数据的多方可以在不泄露数据的情况下联合训练出一个模型,这个模型可以充分利用各方的数据,接近将全部数据集中在一起进行训练的效果。横向联邦学习的一般流程如下:

请添加图片描述

横向联邦学习的过程简单理解如下:

  1. 各个参与方基于自身的数据训练出本地模型,将模型参数发送给公共的服务端;
  2. 服务端将收到的多个模型参数聚合生成全局的模型参数,然后下发给各个参与方;
  3. 参与方使用全局的模型参数更新本地模型,重复这个步骤直到模型训练达到要求;

从上面的过程可以看到,作为一个联邦学习框架,需要关注下面要点:

  1. 参与方本地模型训练;
  2. 模型参数的传输;
  3. 模型的聚合策略;

Flower 框架上手

Flower 是一个轻量的联邦学习框架,提出于 2020 年。一直以来,因为设计良好,方便扩展受到了比较多的关注。团队通过论文 FLOWER: A FRIENDLY FEDERATED LEARNING FRAMEWORK 介绍了框架的设计思想。通过论文可以看到框架设计主要追求下面目标:

  1. 可拓展,支持大量的客户端同时进行模型训练;
  2. 使用灵活,支持异构的客户端,通信协议,隐私策略,支持新功能的开销小;

首先基于 Flower 框架实际进行了一个机器学习的模型训练,通过实际动手可以感受下基于 Flower 框架可以用相当简单的方式实现一个联邦学习模型训练流程。这个流程是参考 Flower Quickstart 实现的。

常规机器学习部分实现

首先实现机器学习训练所需的基础方法,主要是数据集的准备,定义所需的模型,封装训练与测试流程,这部分与联邦学习无关。熟悉 pytorch 的应该很容易理解这部分代码:

from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm


DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 定义神经网络模型

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# 定义模型训练流程

def train(net, trainloader, epochs):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in tqdm(trainloader):
            optimizer.zero_grad()
            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
            optimizer.step()

# 定义模型推理流程

def test(net, testloader):
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    with torch.no_grad():
        for images, labels in tqdm(testloader)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

易迟

高质量内容创作不易,支持下

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值