Tensorflow Federated实现联邦学习中的手写数字图像分类

文章讲述了如何在Python中使用TensorflowFederated进行联邦学习,包括数据预处理、模型构建、FederatedAveraging算法的应用,以及如何通过TensorBoard监控模型在EMNIST数据集上的训练进度。
部署运行你感兴趣的模型镜像

环境配置:

python 3.10

Tensorflow Federated 0.61.0

%load_ext tensorboard
1 导入库
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from matplotlib import pyplot as plt

np.random.seed(0)

 测试tensorflow federated是否成功导入

tff.federated_computation(lambda: 'Hello, World!')()
2 准备输入数据
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

 

可视化联邦数据的异质性

# 客户端样本每层的示例数
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # 为每个标签单独追加计数以绘制图
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# 每个客户端都有不同的平均图像,这意味着每个客户端将在本地将模型推向自己的方向。

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

 

  

3 数据集的预处理

在这里,我们将 28x28 图像平展成 784 元素数组,将个别示例进行洗牌,将其组织成批,并将特征从"像素"和"标签"重命名为" x "和" y ",以便与Keras配合使用。

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """将一个批次的"像素"进行扁平化处理,并将特征作为" OrderedDict "返回."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)
preprocessed_example_dataset = preprocess(example_dataset)

下面是一个简单的辅助函数,它将从给定的用户集合中构造一个数据集列表,作为一轮训练或评估的输入。

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]
4 选择客户端

在一个典型的联邦训练场景中,我们要处理的是潜在的大量用户设备,其中只有一小部分可以在给定的时间点上进行训练。例如,当客户端设备为手机时,手机只有当插上电源参与训练,手机断开计量网络处于空闲状态。当然,我们是在一个模拟的环境中,所有的数据都是本地可用的。通常情况下,当运行仿真时,我们会简单地在每一轮训练中抽取一个随机的客户子集,一般在每一轮中不同。通过学习 Federation Averaging  ( https://arxiv.org/abs/1602.05629)算法)一文可以发现,在每轮客户子集随机采样的系统中实现收敛需要一段时间,而在本交互教程中运行数百轮是不现实的。

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')
Number of client datasets: 10
First dataset: <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>
5 利用Keras创建模型
def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
def model_fn():
  # 我们必须在这里创建一个新的模型,而不能从外部范围捕获它。TFF将在不同的图环境中对此进行调用。
  keras_model = create_keras_model()
  return tff.learning.models.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
6 在联邦数据上训练模型
training_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

调用"初始化"计算来构造服务器状态。

train_state = training_process.initialize()
result = training_process.next(train_state, federated_train_data)
train_state = result.state
train_metrics = result.metrics
print('round  1, metrics={}'.format(train_metrics))
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  result = training_process.next(train_state, federated_train_data)
  train_state = result.state
  train_metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, train_metrics))
7 在TensorBoard中显示模型指标

接下来,使用Tensorboard将这些联邦计算的度量可视化。

logdir = "./logs/training/"
try:
  tf.io.gfile.rmtree(logdir)  # delete any previous results
except tf.errors.NotFoundError as e:
  pass # Ignore if the directory didn't previously exist.
summary_writer = tf.summary.create_file_writer(logdir)
train_state = training_process.initialize()
!ls {logdir}
%tensorboard --logdir {logdir} --port=0

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

<think>好的,用户需要寻找适用于MNIST数据集的联邦学习训练模型代码示例。首先,我需要确认用户的具体需求。联邦学习涉及多个客户端协作训练模型,同时保持数据本地化,适用于隐私保护场景,比如医疗或金融数据。MNIST是一个手写数字数据集,通常用于图像分类任务,因此用户可能希望有一个结合这两个方面的代码示例。 接下来,我需要回忆相关的联邦学习框架。TensorFlow Federated(TFF)是一个常用的库,专门用于联邦学习的开发和实验。它提供了构建联邦学习模型的工具和示例,可能适合用户的需求。此外,PySyft也是一个选项,但考虑到用户可能更倾向于使用更成熟和文档丰富的库,TFF可能更合适。 然后,我需要查找TFF官方文档或示例代码中是否有MNIST的联邦学习实现。通常,TFF的教程会包含这样的示例,例如使用Federated Averaging算法在MNIST上进行训练。需要确保代码示例包括数据预处理、模型定义、联邦训练过程以及评估步骤。 另外,用户提到的引用[2]是关于联邦学习的论文,强调通信效率和数据隐私。在代码示例中,可能需要指出如何减少通信开销,例如通过多轮本地训练后再进行参数聚合,这可以通过TFF的tff.learning.build_federated_averaging_process实现。 还需要考虑代码的兼容性和环境设置。用户可能需要安装TensorFlowTensorFlow Federated,因此需要提供安装命令,如pip install tensorflow-federated。同时,确保代码在最新版本的库中能够运行,可能需要检查是否有API变更。 数据预处理部分,MNIST数据需要被模拟分配到多个客户端。TFF通常提供工具来将集中式数据集划分为联邦数据集,例如使用tff.simulation.datasets.emnist.load_data(),但MNIST可能需要类似的模拟划分。每个客户端的数据分布可以是独立同分布或非独立同分布,示例中通常使用IID划分以简化。 模型定义方面,需要构建一个适用于MNIST的卷积神经网络(CNN)或简单的多层感知机(MLP)。例如,使用tf.keras.Sequential创建模型,然后通过TFF的tff.learning.from_keras_model将其转换为联邦学习模型。 训练循环部分,联邦学习通常涉及多轮通信,每轮中选择一部分客户端进行本地训练,然后聚合模型参数。示例代码应展示如何配置训练循环,设置超参数如训练轮数、客户端数量、批次大小等。 最后,评估模型性能时,需要在中央服务器或测试集上进行测试,确保模型在未见数据上的表现。同时,可能需要指出如何保存和加载训练好的模型以供后续使用。 总结下来,用户的请求需要提供一个结构清晰的联邦学习代码示例,涵盖数据加载、模型定义、训练过程、评估方法,并确保代码可运行。同时,引用相关的联邦学习原理,如通信效率和数据隐私保护,以增强回答的可信度。可能还需要提醒用户注意依赖库的安装和版本兼容性问题,确保示例能够顺利运行。</think>针对联邦学习在MNIST数据集上的代码实现需求,以下是基于TensorFlow Federated(TFF)的完整代码示例及解析。该方案结合了联邦学习的核心特性(如数据去中心化、参数聚合)和MNIST图像分类任务的特点。 ### 一、联邦学习MNIST训练代码实现 ```python import tensorflow as tf import tensorflow_federated as tff # 1. 数据预处理 def preprocess(dataset, batch_size=32): def batch_format(element): return (tf.reshape(element['pixels'], [-1, 28, 28, 1]), tf.reshape(element['label'], [-1, 1])) return dataset.batch(batch_size).map(batch_format) # 加载联邦MNIST数据集 train_data, test_data = tff.simulation.datasets.emnist.load_data( only_digits=True) # 使用数字分类版本 # 模拟客户端数据分布 sample_clients = train_data.client_ids[:10] # 取前10个客户端作为示例 federated_train_data = [preprocess(train_data.create_tf_dataset_for_client(x)) for x in sample_clients] # 2. 模型定义 def create_keras_model(): return tf.keras.models.Sequential([ tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)), tf.keras.layers.MaxPooling2D((2,2)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) # 3. 联邦学习模型包装 def model_fn(): keras_model = create_keras_model() return tff.learning.from_keras_model( keras_model, input_spec=federated_train_data[0].element_spec, loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=[tf.keras.metrics.SparseCategoricalAccuracy()] ) # 4. 联邦平均算法配置 trainer = tff.learning.build_federated_averaging_process( model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02), server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0)) # 5. 训练循环 state = trainer.initialize() NUM_ROUNDS = 10 for round_num in range(NUM_ROUNDS): state, metrics = trainer.next(state, federated_train_data) print(f'Round {round_num+1}, metrics={metrics}') # 6. 模型评估 evaluation = tff.learning.build_federated_evaluation(model_fn) test_metrics = evaluation(state.model, federated_train_data) print(f'Final metrics: {test_metrics}') ``` ### 二、关键实现解析 1. **数据分片**:通过`tff.simulation.datasets.emnist.load_data()`加载预分割的联邦数据集,每个客户端对应一个独立的数据子集[^2]。此示例使用`only_digits=True`参数选择纯数字分类任务。 2. **模型架构**:采用包含卷积层的CNN结构,相比全连接网络更适合图像特征提取: $$ \text{Conv2D}(32,3\times3) \rightarrow \text{MaxPooling} \rightarrow \text{Dense}(128) \rightarrow \text{Dense}(10) $$ 3. **联邦平均算法**:使用`build_federated_averaging_process`实现参数聚合,其核心公式为: $$ w_{t+1} = w_t + \eta_{server} \cdot \frac{1}{K}\sum_{k=1}^K (w_t^k - w_t) $$ 其中$w_t$为全局模型参数,$w_t^k$为客户端k的本地更新参数。 4. **隐私保护**:通过保持数据本地化,避免原始数据传输,符合联邦学习的核心设计原则[^2]。 ###
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值