文献:https://arxiv.org/abs/1602.05629
代码来源:https://github.com/shaoxiongji/federated-learning
参考文章:FedAvg代码详解 - CSDN博客
目录
1、Utils
options.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import argparse # 用于命令行选项,参数和子命令的解释
def args_parser():
    """Build and parse the command-line arguments for federated training.

    Three groups of options are defined: federated parameters, model
    parameters, and miscellaneous parameters. ``type`` converts the raw
    command-line string; ``default`` is used when the flag is absent.

    Returns:
        argparse.Namespace: parsed arguments, read as e.g. ``args.epochs``.
            Example invocation: ``python main_fed.py --epochs 100 --iid``.
    """
    parser = argparse.ArgumentParser()

    # federated arguments:
    #   epochs      - number of communication rounds
    #   num_users   - total number of clients K
    #   frac        - fraction C of clients sampled each round
    #   local_ep    - local epochs E per client per round
    #   local_bs    - local mini-batch size B
    #   bs          - test batch size
    #   lr/momentum - SGD hyper-parameters
    #   split       - train/test split granularity (by user or by sample)
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    # NOTE(review): this flag is parsed as a *string* ('True'/'False'), not a
    # bool — callers must compare against the string; kept for compatibility.
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    # store_true: the attribute is True only when --iid appears on the command line
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    # store_true: verbose logging only when --verbose appears on the command line
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    # store_true: when --all_clients appears, every client participates in aggregation
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')

    args = parser.parse_args()
    return args
sampling.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import numpy as np
from torchvision import datasets, transforms
def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the MNIST dataset.

    Every user receives an equal-sized, non-overlapping random subset of
    sample indices.

    :param dataset: dataset to partition
    :param num_users: number of clients to split the data across
    :return: dict mapping user id -> set of assigned sample indices
    """
    share = len(dataset) // num_users  # samples handed to each user
    remaining = list(range(len(dataset)))
    assignments = {}
    for uid in range(num_users):
        # Draw `share` distinct indices from the not-yet-assigned pool.
        picked = set(np.random.choice(remaining, share, replace=False))
        assignments[uid] = picked
        remaining = list(set(remaining) - picked)
    return assignments
def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300, shards_per_user=2):
    """
    Sample non-I.I.D. client data from the MNIST dataset.

    The (num_shards * num_imgs) training samples are sorted by label and cut
    into ``num_shards`` contiguous shards; each user is handed
    ``shards_per_user`` random shards, so each client sees only a couple of
    digit classes (the pathological non-IID split of the FedAvg paper).

    :param dataset: dataset exposing labels via ``targets`` (modern
        torchvision) or the deprecated ``train_labels`` attribute
    :param num_users: number of clients
    :param num_shards: number of label-sorted shards (default 200)
    :param num_imgs: images per shard (default 300; 200*300 = 60000 = MNIST train size)
    :param shards_per_user: shards assigned to each client (default 2)
    :return: dict mapping user id -> np.ndarray of assigned sample indices
    """
    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)

    # Prefer the modern `targets` attribute; fall back to the deprecated
    # `train_labels` kept by older torchvision releases.
    raw_labels = getattr(dataset, 'targets', None)
    if raw_labels is None:
        raw_labels = dataset.train_labels
    labels = np.asarray(raw_labels)[:num_shards * num_imgs]

    # Sort sample indices by label so shards are label-homogeneous.
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # Hand each user `shards_per_user` random shards without replacement.
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, shards_per_user, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0)
    return dict_users
def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the CIFAR10 dataset.

    :param dataset: dataset to partition
    :param num_users: number of clients
    :return: dict mapping user id -> set of assigned sample indices
    """
    per_user = int(len(dataset) / num_users)
    pool = [idx for idx in range(len(dataset))]
    partition = {}
    for user in range(num_users):
        # Sample without replacement from the remaining pool, then shrink it.
        partition[user] = set(np.random.choice(pool, per_user, replace=False))
        pool = list(set(pool) - partition[user])
    return partition
if __name__ == '__main__':
    # Smoke test: build the MNIST training set (downloading it if absent)
    # and run the non-IID sampler over 100 simulated clients.
    preprocess = transforms.Compose([
        transforms.ToTensor(),                       # PIL image -> tensor
        transforms.Normalize((0.1307,), (0.3081,)),  # per-channel normalization
    ])
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=preprocess)
    num = 100
    d = mnist_noniid(dataset_train, num)
2、models
Nets.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import torch
from torch import nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Two-layer perceptron: input -> hidden (dropout + ReLU) -> output."""

    def __init__(self, dim_in, dim_hidden, dim_out):
        """dim_in/dim_hidden/dim_out: input, hidden and output widths."""
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()  # regularization between the two layers
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        # Flatten everything except the (inferred) batch dimension.
        flat = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        hidden = self.dropout(self.layer_input(flat))
        return self.layer_hidden(self.relu(hidden))
class CNNMnist(nn.Module):
    """Small LeNet-style CNN for MNIST: two conv+pool stages, two FC layers.

    Expects ``args`` to provide ``num_channels`` (input channels) and
    ``num_classes`` (output width). The fc1 input of 320 assumes 28x28 inputs.
    """

    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()  # channel-wise dropout on conv2 maps
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        # conv -> 2x2 max-pool -> ReLU, twice (dropout on the second stage)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten the feature maps for the fully connected head.
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        x = F.dropout(F.relu(self.fc1(x)), training=self.training)
        return self.fc2(x)
class CNNCifar(nn.Module):
    """Classic CIFAR-10 CNN (PyTorch tutorial style): 2 conv blocks + 3 FC.

    Expects ``args.num_classes`` for the output width; input is 3x32x32.
    """

    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # shared 2x2 pooling layer
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        first = self.pool(F.relu(self.conv1(x)))
        second = self.pool(F.relu(self.conv2(first)))
        # 32x32 input shrinks to 16 maps of 5x5 after the two blocks.
        flat = second.view(-1, 16 * 5 * 5)
        flat = F.relu(self.fc1(flat))
        flat = F.relu(self.fc2(flat))
        return self.fc3(flat)
Update.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics
class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the samples at ``idxs``."""

    def __init__(self, dataset, idxs):
        # Materialize idxs as a list so it supports len() and indexing.
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        # Size of the restricted view, not the underlying dataset.
        return len(self.idxs)

    def __getitem__(self, index):
        # Translate the local index into the underlying dataset's index.
        sample, target = self.dataset[self.idxs[index]]
        return sample, target
class LocalUpdate(object):
    """Local SGD training for a single federated client.

    ``args`` carries the hyper-parameters (local_bs, lr, momentum, local_ep,
    device, verbose); ``dataset`` is the full training set and ``idxs`` are
    the sample indices owned by this client.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        # Restrict the dataset to this client's shard and batch + shuffle it.
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs),
                                    batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        """Train ``net`` for ``local_ep`` epochs; return (state_dict, mean loss)."""
        net.train()
        # Plain SGD with the configured learning rate and momentum.
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr,
                                    momentum=self.args.momentum)
        per_epoch_loss = [self._run_epoch(net, optimizer, ep)
                          for ep in range(self.args.local_ep)]
        return net.state_dict(), sum(per_epoch_loss) / len(per_epoch_loss)

    def _run_epoch(self, net, optimizer, ep):
        """One pass over the client's loader; returns the mean batch loss."""
        batch_losses = []
        for batch_idx, (images, labels) in enumerate(self.ldr_train):
            images, labels = images.to(self.args.device), labels.to(self.args.device)
            net.zero_grad()
            log_probs = net(images)
            loss = self.loss_func(log_probs, labels)
            loss.backward()
            optimizer.step()
            if self.args.verbose and batch_idx % 10 == 0:  # throttle logging
                print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    ep, batch_idx * len(images), len(self.ldr_train.dataset),
                    100. * batch_idx / len(self.ldr_train), loss.item()))
            batch_losses.append(loss.item())
        return sum(batch_losses) / len(batch_losses)
Fed.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import torch
from torch import nn
def FedAvg(w):
    """Element-wise average of a list of model state dicts (FedAvg step).

    :param w: non-empty list of state dicts with identical keys/shapes
    :return: new state dict whose every tensor is the mean over ``w``
    """
    # Deep-copy the first client's weights so accumulation never touches
    # the caller's tensors.
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for other in w[1:]:
            w_avg[key] += other[key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
test.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
def test_img(net_g, datatest, args): # 用于评估训练好的模型在测试集上的性能 三个参数,模型,数据集,其它参数
net_g.eval() # 将模型设置为评估模式,这意味着在推理阶段不会进行梯度计算。
# testing
test_loss = 0 # 初始化损失
correct = 0 # 正确分类的数量
data_loader = DataLoader(datatest, batch_size=args.bs) # 数据载入
l = len(data_loader)
for idx, (data, target) in enumerate(data_loader): # 在每次循环中,从数据加载器中获取一批测试样本和对应的标签
if args.gpu != -1: # 用GPU计算
data, target = data.cuda(), target.cuda()
log_probs = net_g(data) # data传入模型,得到预测输出
# sum up batch loss 计算当前批次的损失
test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
# get the index of the max log-probability 通过这段代码,可以得到log_probs张量中每一行的最大值对应的索引,即预测结果y_pred
y_pred = log_probs.data.max(1, keepdim=True)[1]
# 比较y_pred和target是否相等,返回一个布尔张量,long转换成长征型,移到CPU上,累加
# 统计预测值y_pred和目标值target之间匹配正确的数量,并将这个数量累加到变量correct中
correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
test_loss /= len(data_loader.dataset) # 计算平均损失值
accuracy = 100.00 * correct / len(data_loader.dataset) # 准确率
if args.verbose: # 打印
print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, correct, len(data_loader.dataset), accuracy))
return accuracy, test_loss
main_fed.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img
if __name__ == '__main__':
    # parse args
    args = args_parser()
    # Run on GPU `args.gpu` when CUDA is available and requested, else CPU.
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users: assign training indices to clients, IID or non-IID
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            # Non-IID CIFAR sampling is not implemented in this codebase.
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape  # shape of one sample image

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        # MLP input width = product of the image dimensions (flattened).
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()  # training mode (enables dropout etc.)

    # copy weights: w_glob holds the current global model parameters
    w_glob = net_glob.state_dict()

    # training bookkeeping
    loss_train = []           # average client loss per communication round
    cv_loss, cv_acc = [], []  # validation loss/accuracy (unused below)
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    if args.all_clients:
        # Every client participates: pre-fill one weight slot per client
        # so results can be written by index inside the round loop.
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]
    for iter in range(args.epochs):  # one communication round per iteration
        loss_locals = []
        if not args.all_clients:
            w_locals = []  # collect only the sampled clients' weights
        # Sample m = max(C*K, 1) distinct clients for this round.
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            # Local training: each client starts from a copy of the global model.
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights by federated averaging
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

    # testing
    net_glob.eval()
    # NOTE(review): loss_train (the per-round list) is overwritten here with
    # the scalar training-set loss returned by test_img.
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))
main_nn.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms
from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar
def test(net_g, data_loader):# 用于模型在测试集上进行评估的函数
# testing
net_g.eval() # 模型设置为评估模式
test_loss = 0 # 损失初始化
correct = 0 # 正确数量初始化
l = len(data_loader)
for idx, (data, target) in enumerate(data_loader): # 按批次索引从数据加载器中取数据以及对应的标签
data, target = data.to(args.device), target.to(args.device) # 将数据和标签转移到device
log_probs = net_g(data) # 将数据传入网络,得到预测的对数概率
test_loss += F.cross_entropy(log_probs, target).item() # 计算交叉熵损失并累加
y_pred = log_probs.data.max(1, keepdim=True)[1] #通过这段代码,可以得到log_probs张量中每一行的最大值对应的索引,即预测结果y_pred
correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() # 正确的类别累加
test_loss /= len(data_loader.dataset)
print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, correct, len(data_loader.dataset),
100. * correct / len(data_loader.dataset)))
return correct, test_loss
if __name__ == '__main__':
    # parse args
    args = args_parser()
    # Run on GPU `args.gpu` when CUDA is available and requested, else CPU.
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    # Fix the RNG seed so repeated runs are reproducible.
    torch.manual_seed(args.seed)

    # load dataset and split users
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
        img_size = dataset_train[0][0].shape  # shape of one sample image
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model: pick the architecture from the parsed options
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        # MLP input width = product of the image dimensions (flattened).
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # training
    # Plain (centralized) SGD over the whole training set — the non-federated baseline.
    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

    list_loss = []  # average loss per epoch
    net_glob.train()  # training mode (enables dropout etc.)
    for epoch in range(args.epochs):
        batch_loss = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(args.device), target.to(args.device)
            optimizer.zero_grad()  # clear accumulated gradients
            output = net_glob(data)
            loss = F.cross_entropy(output, target)
            loss.backward()   # backprop
            optimizer.step()  # parameter update
            if batch_idx % 50 == 0:  # throttle progress logging
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
            batch_loss.append(loss.item())
        loss_avg = sum(batch_loss)/len(batch_loss)  # mean batch loss of the epoch
        print('\nTrain loss:', loss_avg)
        list_loss.append(loss_avg)

    # plot loss: training-loss-vs-epoch curve saved as a PNG
    plt.figure()
    plt.plot(range(len(list_loss)), list_loss)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))

    # testing
    if args.dataset == 'mnist':
        dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307,), (0.3081,))
                                      ]))
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')

    print('test on', len(dataset_test), 'samples')
    test_acc, test_loss = test(net_glob, test_loader)