Translated from Quickstart TensorFlow - Flower 1.5.0
Quickstart TensorFlow
Build a federated learning system in fewer than 20 lines of code
Before you can import the Flower framework, you need to install it:
$ pip install flwr
Because we want to use the Keras API of TensorFlow (TF), we also have to install TF:
$ pip install tensorflow
Flower Client
Next, in a file named client.py, import Flower and TensorFlow:
import flwr as fl
import tensorflow as tf
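We use the Keras utilities of TF to load CIFAR10, a popular color image classification dataset for machine learning. The call to tf.keras.datasets.cifar10.load_data() downloads CIFAR10, caches it locally, and then returns the entire training and test sets as NumPy arrays.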
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
Next, we need a model. For the purposes of this tutorial, we use MobileNetV2 with 10 output classes:
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
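The Flower server interacts with clients through an interface called Client. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code (i.e., to train the neural network we defined earlier).
Flower provides a convenience class called NumPyClient, which makes it easier to implement the Client interface when your workload uses Keras. The NumPyClient interface defines three methods, which can be implemented in the following way: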
class CifarClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        # Return the current local model weights as a list of NumPy arrays.
        return model.get_weights()

    def fit(self, parameters, config):
        # Load the weights sent by the server, train locally, return the update.
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        # Load the weights sent by the server and evaluate on the local test set.
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)
        return loss, len(x_test), {"accuracy": float(accuracy)}
We can now create an instance of our class CifarClient and add one line to actually run this client:
fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient())
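That's it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client() or fl.client.start_numpy_client(). The string "[::]:8080" tells the client which server to connect to. In our example we can run the server and the client on the same machine, so we use "[::]:8080". If we run a truly federated workload, with the server and the clients on different machines, all that needs to change is the server_address we point the client at.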
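Since server_address is the only thing that changes between a local run and a distributed one, one option is to read it from the command line instead of hard-coding it. The following is a minimal sketch; the --server-address flag and the parse_server_address helper are illustrative assumptions, not part of Flower.

```python
import argparse


def parse_server_address(argv=None):
    """Return the server address given on the command line.

    The --server-address flag and its default value are assumptions made
    for this sketch; they are not defined by Flower.
    """
    parser = argparse.ArgumentParser(description="Flower client (sketch)")
    parser.add_argument(
        "--server-address",
        default="[::]:8080",  # the address the tutorial uses for a local run
        help="host:port of the Flower server to connect to",
    )
    args = parser.parse_args(argv)
    return args.server_address


# Explicit argument lists are used here so the sketch does not depend on
# how the script itself was invoked.
local_address = parse_server_address([])
remote_address = parse_server_address(["--server-address", "192.168.1.10:8080"])
```

In client.py, the returned string would then be passed as server_address=... in place of the literal "[::]:8080".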
Flower Server
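For simple workloads we can start a Flower server and leave all the configuration possibilities at their default values. In a file named server.py, import Flower and start the server: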
import flwr as fl
fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
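Leaving the configuration at its defaults means the server falls back to Flower's default strategy, federated averaging (FedAvg). The NumPy sketch below illustrates the idea behind that aggregation step: client weights are averaged layer by layer, weighted by the number of examples each client reported from fit. The weighted_average helper and the toy arrays are assumptions for illustration, not Flower's own implementation.

```python
import numpy as np


def weighted_average(results):
    """Average per-layer weights, weighted by each client's example count.

    `results` is a list of (weights, num_examples) pairs, where `weights`
    is a list of NumPy arrays (one per layer), mirroring the first two
    values returned by fit().
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    averaged = []
    for layer in range(num_layers):
        # Sum this layer across clients, scaling each by its example count.
        layer_sum = sum(
            weights[layer] * num_examples for weights, num_examples in results
        )
        averaged.append(layer_sum / total_examples)
    return averaged


# Two toy "clients", each holding a single-layer model.
client_a = ([np.array([1.0, 2.0])], 100)
client_b = ([np.array([3.0, 4.0])], 300)
global_weights = weighted_average([client_a, client_b])
```

Because client_b reports three times as many examples, the averaged weights sit closer to its values than to client_a's.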
Train the model, federated!
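With both the client and the server ready, we can now run everything and see federated learning in action. FL systems usually have one server and multiple clients. We therefore have to start the server first: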
$ python server.py
Once the server is running, we can start the clients in different terminals. Open a new terminal and start the first client:
$ python client.py
Open another terminal and start the second client:
$ python client.py
Each client will have its own dataset.
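Note that client.py as written in this tutorial loads the full CIFAR10 arrays in every client process. To give each client a distinct shard, one option is to slice the arrays by a client index. The following is a minimal sketch; the shard_bounds helper and the idea of a user-supplied client index are assumptions for illustration and do not come from Flower.

```python
def shard_bounds(num_examples, num_clients, client_id):
    """Return (start, stop) indices of the contiguous shard for one client.

    Examples are split as evenly as possible; the first `remainder`
    clients receive one extra example each.
    """
    if not 0 <= client_id < num_clients:
        raise ValueError("client_id must be in [0, num_clients)")
    base, remainder = divmod(num_examples, num_clients)
    start = client_id * base + min(client_id, remainder)
    stop = start + base + (1 if client_id < remainder else 0)
    return start, stop


# With the 50,000 CIFAR10 training images split across two clients:
start, stop = shard_bounds(50_000, 2, client_id=0)
# x_train[start:stop] and y_train[start:stop] would then stand in for the
# full arrays inside this client's fit() method.
```

Contiguous slices keep the sketch short; shuffling the indices first would be a natural refinement if the data were ordered by class.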
You should now see how the training progresses in the first terminal (the one that started the server):
INFO flower 2021-02-25 14:15:46,741 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-02-25 14:15:46,742 | server.py:72 | Getting initial parameters
INFO flower 2021-02-25 14:16:01,770 | server.py:74 | Evaluating initial parameters
INFO flower 2021-02-25 14:16:01,770 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-02-25 14:16:12,341 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:21:17,235 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:21:17,512 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:21:29,628 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:21:29,696 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:25:59,917 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:26:00,227 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:26:11,457 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:26:11,530 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:30:43,389 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:30:43,630 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:30:53,384 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:30:53,384 | server.py:122 | [TIME] FL finished in 891.6143046000007
INFO flower 2021-02-25 14:30:53,385 | app.py:109 | app_fit: losses_distributed [(1, 2.3196680545806885), (2, 2.3202896118164062), (3, 2.1818180084228516)]
INFO flower 2021-02-25 14:30:53,385 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-02-25 14:30:53,385 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-02-25 14:30:53,385 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-02-25 14:30:53,442 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:31:02,848 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:31:02,848 | app.py:121 | app_evaluate: federated loss: 2.1818180084228516
INFO flower 2021-02-25 14:31:02,848 | app.py:125 | app_evaluate: results [('ipv4:127.0.0.1:57158', EvaluateRes(loss=2.1818180084228516, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.21610000729560852})), ('ipv4:127.0.0.1:57160', EvaluateRes(loss=2.1818180084228516, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.21610000729560852}))]
INFO flower 2021-02-25 14:31:02,848 | app.py:127 | app_evaluate: failures []
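The federated loss reported at the end of the log is aggregated from the per-client evaluation results. The short sketch below computes an example-weighted average over the values that appear in the app_evaluate log line; weighted_mean is a hypothetical helper name, and the exact aggregation used in a real run depends on the strategy.

```python
def weighted_mean(results):
    """Average values weighted by the number of examples behind each one.

    `results` is a list of (num_examples, value) pairs.
    """
    total_examples = sum(num_examples for num_examples, _ in results)
    weighted_sum = sum(num_examples * value for num_examples, value in results)
    return weighted_sum / total_examples


# Per-client numbers copied from the app_evaluate log line.
losses = [(10000, 2.1818180084228516), (10000, 2.1818180084228516)]
accuracies = [(10000, 0.21610000729560852), (10000, 0.21610000729560852)]

federated_loss = weighted_mean(losses)
federated_accuracy = weighted_mean(accuracies)
```

Both clients report identical numbers here because each evaluates the same global model on the same test set, so the weighted mean equals either client's value.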
Congratulations! You have successfully built and run your first federated learning system. The full source code can be found on github.com.