nndl包的下载
链接:https://pan.baidu.com/s/1WvTC_O8WKImVyiqQMOZ-VA?pwd=abcd
提取码:abcd
--来自百度网盘超级会员V4的分享
1、数据集的构建
Moon1000数据集,其中训练集640条、验证集160条、测试集200条
该数据集的数据是从两个带噪音的弯月形状数据分布中采样得到,每个样本包含2个特征。
# Dataset helper from the nndl package.
from nndl.dataset import make_moons

# Draw 1000 samples from two noisy half-moon distributions (2 features each).
n_samples = 1000
X, y = make_moons(n_samples=n_samples, shuffle=True, noise=0.1)

# Split sizes: 640 train / 160 dev / 200 test (640 + 160 + 200 == n_samples).
num_train = 640
num_dev = 160
num_test = 200

# Consecutive, non-overlapping slices of the shuffled samples.
X_train, y_train = X[:num_train], y[:num_train]
X_dev = X[num_train:num_train + num_dev]
y_dev = y[num_train:num_train + num_dev]
X_test = X[num_train + num_dev:]
y_test = y[num_train + num_dev:]

# Turn each label vector into a column of shape [N, 1] so it lines up with
# the model's single-column output.
y_train, y_dev, y_test = (labels.reshape([-1, 1]) for labels in (y_train, y_dev, y_test))
2、模型构建
2.1线性层算子
from nndl.op import Op
import torch
import numpy as np
# Linear (fully connected) layer operator: outputs = inputs @ W + b
class Linear(Op):
    """Fully connected layer computing ``inputs @ W + b``.

    Args:
        input_size: number of input features (rows of W).
        output_size: number of output features (columns of W and b).
        name: name identifying this layer.
        weight_init: callable taking a shape list ``[input_size, output_size]``
            and returning the initial weights; the result is converted to a
            float32 tensor.
        bias_init: callable taking a shape list ``[1, output_size]`` and
            returning the initial bias; converted to a float32 tensor just
            like the weights, so NumPy-based initializers work too.
    """

    def __init__(self, input_size, output_size, name,
                 weight_init=np.random.standard_normal, bias_init=torch.zeros):
        self.params = {}
        # np.random.standard_normal returns a float64 NumPy array, so convert
        # it to a float32 tensor to match the rest of the computation.
        self.params['W'] = torch.as_tensor(
            weight_init([input_size, output_size]), dtype=torch.float32)
        # Bias row [1, output_size]; it broadcasts over the rows of the matmul
        # result. torch.as_tensor is a no-op for the default torch.zeros.
        self.params['b'] = torch.as_tensor(
            bias_init([1, output_size]), dtype=torch.float32)
        self.inputs = None
        self.name = name

    def forward(self, inputs):
        """Cache ``inputs`` and return ``inputs @ W + b``."""
        self.inputs = inputs
        outputs = torch.matmul(self.inputs, self.params['W']) + self.params['b']
        return outputs
2.2Logistic层算子
class Logistic(Op):
    """Element-wise logistic (sigmoid) activation: ``1 / (1 + exp(-x))``."""

    def __init__(self):
        self.inputs = None
        self.outputs = None

    def forward(self, inputs):
        """Apply the sigmoid element-wise, cache the result and return it."""
        denominator = 1.0 + torch.exp(-inputs)
        self.outputs = 1.0 / denominator
        return self.outputs
2.3层的串行组合
# Two-layer feed-forward network: Linear -> Logistic -> Linear -> Logistic
class Model_MLP_L2(Op):
    """Two-layer MLP with a logistic activation after each linear layer."""

    def __init__(self, input_size, hidden_size, output_size):
        # Hidden layer and its activation.
        self.fc1 = Linear(input_size, hidden_size, name="fc1")
        self.act_fn1 = Logistic()
        # Output layer and its activation.
        self.fc2 = Linear(hidden_size, output_size, name="fc2")
        self.act_fn2 = Logistic()

    def __call__(self, X):
        return self.forward(X)

    def forward(self, X):
        """Run X through fc1 -> act_fn1 -> fc2 -> act_fn2 and return the output."""
        hidden = self.act_fn1(self.fc1(X))
        return self.act_fn2(self.fc2(hidden))
测试一下
# Smoke test: a model with 5 inputs, 10 hidden units and 1 output.
model = Model_MLP_L2(input_size=5, hidden_size=10, output_size=1)
# One random sample with 5 features drawn uniformly from [0, 1).
X = torch.rand([1, 5])
result = model(X)
print("result: ", result)
3、损失函数
# Binary cross-entropy loss
class BinaryCrossEntropyLoss(Op):
    """Mean binary cross-entropy between predicted probabilities and labels.

    loss = -1/N * sum(y * log(p) + (1 - y) * log(1 - p))

    Args:
        eps: predictions are clamped to ``[eps, 1 - eps]`` before the log is
            taken. A float32 sigmoid saturates to exactly 0.0 or 1.0 for large
            |x|, and without the clamp log(0) would make the loss inf/NaN.
            Values already inside the interval are unaffected.
    """

    def __init__(self, eps=1e-7):
        self.predicts = None
        self.labels = None
        self.num = None
        self.eps = eps

    def __call__(self, predicts, labels):
        return self.forward(predicts, labels)

    def forward(self, predicts, labels):
        """Return the loss as a 1-element tensor.

        Args:
            predicts: predicted probabilities, expected shape [N, 1].
            labels: 0/1 labels as a float tensor, expected shape [N, 1].
        """
        self.predicts = predicts
        self.labels = labels
        self.num = self.predicts.shape[0]
        p = torch.clamp(self.predicts, self.eps, 1.0 - self.eps)
        # labels.t() is [1, N]; each matmul sums over the N samples -> [1, 1].
        loss = -1. / self.num * (torch.matmul(self.labels.t(), torch.log(p))
                                 + torch.matmul((1 - self.labels.t()), torch.log(1 - p)))
        loss = torch.squeeze(loss, 1)
        return loss
4、模型优化
反向传播
损失函数
# Binary cross-entropy loss that can back-propagate into the model
class BinaryCrossEntropyLoss(Op):
    """Mean binary cross-entropy with a manual backward pass.

    loss = -1/N * sum(y * log(p) + (1 - y) * log(1 - p))

    Args:
        model: network whose ``backward(grad)`` receives dloss/dpredicts.
        eps: predictions are clamped to ``[eps, 1 - eps]`` in both forward and
            backward. A saturated float32 sigmoid output of exactly 0.0 or 1.0
            would otherwise give log(0) = -inf in the loss and a division by
            zero (inf/NaN) in the gradient. Values already inside the interval
            are unaffected.
    """

    def __init__(self, model, eps=1e-7):
        self.predicts = None
        self.labels = None
        self.num = None
        self.model = model
        self.eps = eps

    def __call__(self, predicts, labels):
        return self.forward(predicts, labels)

    def _clamped_predicts(self):
        """Return the cached predictions clipped to [eps, 1 - eps]."""
        return torch.clamp(self.predicts, self.eps, 1.0 - self.eps)

    def forward(self, predicts, labels):
        """Cache inputs and return the loss as a 1-element tensor.

        Args:
            predicts: predicted probabilities, expected shape [N, 1].
            labels: 0/1 labels as a float tensor, expected shape [N, 1].
        """
        self.predicts = predicts
        self.labels = labels
        self.num = self.predicts.shape[0]
        p = self._clamped_predicts()
        # labels.t() is [1, N]; each matmul sums over the N samples -> [1, 1].
        loss = -1. / self.num * (torch.matmul(self.labels.t(), torch.log(p))
                                 + torch.matmul((1 - self.labels.t()), torch.log(1 - p)))
        loss = torch.squeeze(loss, dim=1)
        return loss

    def backward(self):
        """Compute dloss/dpredicts and hand it to ``model.backward``."""
        p = self._clamped_predicts()
        # d/dp of the mean BCE: -(y / p - (1 - y) / (1 - p)) / N
        loss_grad_predicts = -1.0 * (self.labels / p -
                                     (1 - self.labels) / (1 - p)) / self.num
        self.model.backward(loss_grad_predicts)
Logistic算子
class Logistic(Op):
    """Sigmoid activation with a backward pass for manual back-propagation."""

    def __init__(self):
        self.inputs = None
        self.outputs = None
        # No trainable parameters: code that iterates over layers skips any
        # layer whose params is not a dict.
        self.params = None

    def forward(self, inputs):
        """Apply ``1 / (1 + exp(-x))`` element-wise and cache the output."""
        self.outputs = 1.0 / (1.0 + torch.exp(-inputs))
        return self.outputs

    def backward(self, grads):
        """Chain rule: dL/dx = dL/dy * y * (1 - y), with y the cached output."""
        local_grad = torch.multiply(self.outputs, (1.0 - self.outputs))
        return torch.multiply(grads, local_grad)
线性层
class Linear(Op):
    """Fully connected layer ``inputs @ W + b`` with manual gradients.

    Args:
        input_size: number of input features (rows of W).
        output_size: number of output features (columns of W and b).
        name: name identifying this layer.
        weight_init: callable taking ``[input_size, output_size]`` and returning
            the initial weights; converted to a float32 tensor.
        bias_init: callable taking ``[1, output_size]`` and returning the
            initial bias; converted to a float32 tensor like the weights.
    """

    def __init__(self, input_size, output_size, name,
                 weight_init=np.random.standard_normal, bias_init=torch.zeros):
        self.params = {}
        # np.random.standard_normal yields float64 NumPy data -> float32 tensor.
        self.params['W'] = torch.as_tensor(
            weight_init([input_size, output_size]), dtype=torch.float32)
        # torch.as_tensor is a no-op for the default torch.zeros initializer.
        self.params['b'] = torch.as_tensor(
            bias_init([1, output_size]), dtype=torch.float32)
        self.inputs = None
        # Gradients of the loss w.r.t. each entry of params, set by backward().
        self.grads = {}
        self.name = name

    def forward(self, inputs):
        """Cache ``inputs`` and return ``inputs @ W + b``."""
        self.inputs = inputs
        outputs = torch.matmul(self.inputs, self.params['W']) + self.params['b']
        return outputs

    def backward(self, grads):
        """Store parameter gradients and return the gradient w.r.t. the input.

        Args:
            grads: dloss/doutputs, same shape as the forward output.

        Returns:
            dloss/dinputs, same shape as the cached inputs.
        """
        self.grads['W'] = torch.matmul(self.inputs.T, grads)
        # keepdim=True keeps the [1, output_size] shape of params['b'], so the
        # gradient matches its parameter instead of relying on broadcasting.
        self.grads['b'] = torch.sum(grads, dim=0, keepdim=True)
        return torch.matmul(grads, self.params['W'].T)
整个网络
class Model_MLP_L2(Op):
    """Two-layer MLP (Linear -> Logistic -> Linear -> Logistic) with backward."""

    def __init__(self, input_size, hidden_size, output_size):
        # Hidden linear layer and its logistic activation.
        self.fc1 = Linear(input_size, hidden_size, name="fc1")
        self.act_fn1 = Logistic()
        # Output linear layer and its logistic activation.
        self.fc2 = Linear(hidden_size, output_size, name="fc2")
        self.act_fn2 = Logistic()
        # Layers in forward order; the optimizer and save/load iterate this.
        self.layers = [self.fc1, self.act_fn1, self.fc2, self.act_fn2]

    def __call__(self, X):
        return self.forward(X)

    def forward(self, X):
        """Forward pass; returns the output of the final logistic layer."""
        z1 = self.fc1(X)
        a1 = self.act_fn1(z1)
        z2 = self.fc2(a1)
        a2 = self.act_fn2(z2)
        return a2

    def backward(self, loss_grad_a2):
        """Back-propagate dloss/da2 through the layers in reverse order.

        Each Linear layer records its parameter gradients in ``grads`` as a
        side effect. Returns dloss/dX, the gradient w.r.t. the network input
        (previously computed but discarded).
        """
        loss_grad_z2 = self.act_fn2.backward(loss_grad_a2)
        loss_grad_a1 = self.fc2.backward(loss_grad_z2)
        loss_grad_z1 = self.act_fn1.backward(loss_grad_a1)
        loss_grad_inputs = self.fc1.backward(loss_grad_z1)
        return loss_grad_inputs
优化器
from nndl.opitimizer import Optimizer
class BatchGD(Optimizer):
    """Full-batch gradient descent: ``param <- param - init_lr * grad``."""

    def __init__(self, init_lr, model):
        super().__init__(init_lr=init_lr, model=model)

    def step(self):
        """Apply one update to every layer that owns a params dict."""
        for layer in self.model.layers:
            # Layers without trainable parameters (params is None) are skipped.
            if not isinstance(layer.params, dict):
                continue
            for key in tuple(layer.params):
                layer.params[key] = layer.params[key] - self.init_lr * layer.grads[key]
5、完善Runner类:RunnerV2_1
import os
class RunnerV2_1(object):
    """Full-batch training loop with dev evaluation, checkpointing and logging.

    Args:
        model: callable network; ``save_model``/``load_model`` additionally
            need ``model.layers``, each layer having ``params`` and ``name``.
        optimizer: object whose ``step()`` updates the model parameters.
        metric: callable ``metric(predicts, labels)`` returning a 1-element
            tensor score.
        loss_fn: callable ``loss_fn(predicts, labels)`` returning a 1-element
            tensor loss, with a ``backward()`` that fills parameter gradients.
    """

    def __init__(self, model, optimizer, metric, loss_fn, **kwargs):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.metric = metric
        # Per-epoch evaluation-metric history (train / dev).
        self.train_scores = []
        self.dev_scores = []
        # Per-epoch loss history (train / dev).
        self.train_loss = []
        self.dev_loss = []

    def train(self, train_set, dev_set, **kwargs):
        """Train for ``num_epochs`` full-batch epochs.

        Keyword Args:
            num_epochs: number of epochs (default 0).
            log_epochs: print the training loss every this many epochs
                (default 100; a falsy value disables the periodic log).
            save_dir: directory where parameters are saved each time the dev
                score improves (default None: never save).
        """
        num_epochs = kwargs.get("num_epochs", 0)
        log_epochs = kwargs.get("log_epochs", 100)
        save_dir = kwargs.get("save_dir", None)
        # Best dev score seen so far.
        best_score = 0
        # The training set is the same every epoch, so unpack it once.
        X, y = train_set
        for epoch in range(num_epochs):
            # Forward pass and training loss.
            logits = self.model(X)
            trn_loss = self.loss_fn(logits, y)  # 1-element tensor
            self.train_loss.append(trn_loss.item())
            trn_score = self.metric(logits, y).item()
            self.train_scores.append(trn_score)
            # Backward pass (gradients are stored in the layers) and update.
            self.loss_fn.backward()
            self.optimizer.step()

            dev_score, dev_loss = self.evaluate(dev_set)
            # Checkpoint whenever the dev score improves.
            if dev_score > best_score:
                print(f"[Evaluate] best accuracy performence has been updated: {best_score:.5f} --> {dev_score:.5f}")
                best_score = dev_score
                if save_dir:
                    self.save_model(save_dir)

            if log_epochs and epoch % log_epochs == 0:
                print(f"[Train] epoch: {epoch}/{num_epochs}, loss: {trn_loss.item()}")

    def evaluate(self, data_set):
        """Evaluate on ``data_set``, record dev loss/score, return ``(score, loss)``."""
        X, y = data_set
        logits = self.model(X)
        loss = self.loss_fn(logits, y).item()
        self.dev_loss.append(loss)
        score = self.metric(logits, y).item()
        self.dev_scores.append(score)
        return score, loss

    def predict(self, X):
        """Return the model output for ``X``."""
        return self.model(X)

    def save_model(self, save_dir):
        """Save each parameterized layer to ``<save_dir>/<layer.name>.pdparams``.

        ``save_dir`` is created if it does not exist yet. Without this,
        torch.save would fail on the first checkpoint.
        """
        os.makedirs(save_dir, exist_ok=True)
        for layer in self.model.layers:
            if isinstance(layer.params, dict):
                torch.save(layer.params, os.path.join(save_dir, layer.name + ".pdparams"))

    def load_model(self, model_dir):
        """Load each parameterized layer's params from ``<model_dir>/<name>.pdparams``.

        Only files ending in ``.pdparams`` are considered, so other files in
        the directory are ignored. Raises KeyError if a layer's file is missing.
        """
        suffix = ".pdparams"
        # Map layer name -> parameter file path.
        name_file_dict = {
            file_name[:-len(suffix)]: os.path.join(model_dir, file_name)
            for file_name in os.listdir(model_dir)
            if file_name.endswith(suffix)
        }
        for layer in self.model.layers:
            if isinstance(layer.params, dict):
                file_path = name_file_dict[layer.name]
                layer.params = torch.load(file_path, weights_only=True)
6、模型训练
from nndl.metric import accuracy

epoch_num = 1000
# Directory for the best model's parameters. Raw string so the Windows
# backslashes are taken literally instead of as (invalid) escape sequences,
# which Python warns about; the resulting path value is unchanged.
model_saved_dir = r'G:\大三上学期\深度学习\实验6'
# Input layer dimension: 2 (each moon sample has 2 features)
input_size = 2
# Hidden layer dimension: 4
hidden_size = 4
# Output layer dimension: 1 (a single sigmoid probability per sample)
output_size = 1
# Build the network
model = Model_MLP_L2(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Loss function; it holds the model so its backward() can propagate gradients
loss_fn = BinaryCrossEntropyLoss(model)
# Optimizer
learning_rate = 0.1
optimizer = BatchGD(learning_rate, model)
# Evaluation metric
metric = accuracy
# Create the runner with the training configuration and train
runner = RunnerV2_1(model, optimizer, metric, loss_fn)
runner.train([X_train, y_train], [X_dev, y_dev], num_epochs=epoch_num, log_epochs=50, save_dir=model_saved_dir)
运行结果(注:下面的日志在粘贴时被重复了多份,每份都从 epoch 0 开始且数值完全相同,应为同一次训练的输出)
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.87500 --> 0.88125
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.87500 --> 0.88125
[Train] epoch: 250/1000, loss: 0.36777132749557495
[Train] epoch: 300/1000, loss: 0.3435978889465332
[Evaluate] best accuracy performence has been updated: 0.88125 --> 0.88750
[Train] epoch: 350/1000, loss: 0.32495132088661194
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.87500 --> 0.88125
[Train] epoch: 250/1000, loss: 0.36777132749557495
[Train] epoch: 300/1000, loss: 0.3435978889465332
[Evaluate] best accuracy performence has been updated: 0.88125 --> 0.88750
[Train] epoch: 350/1000, loss: 0.32495132088661194
[Train] epoch: 400/1000, loss: 0.31017178297042847
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.87500 --> 0.88125
[Train] epoch: 250/1000, loss: 0.36777132749557495
[Train] epoch: 300/1000, loss: 0.3435978889465332
[Evaluate] best accuracy performence has been updated: 0.88125 --> 0.88750
[Train] epoch: 350/1000, loss: 0.32495132088661194
[Train] epoch: 400/1000, loss: 0.31017178297042847
[Evaluate] best accuracy performence has been updated: 0.88750 --> 0.89375
[Train] epoch: 450/1000, loss: 0.2981906533241272
[Train] epoch: 500/1000, loss: 0.28831884264945984
[Train] epoch: 550/1000, loss: 0.28009912371635437
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.87500 --> 0.88125
[Train] epoch: 250/1000, loss: 0.36777132749557495
[Train] epoch: 300/1000, loss: 0.3435978889465332
[Evaluate] best accuracy performence has been updated: 0.88125 --> 0.88750
[Train] epoch: 350/1000, loss: 0.32495132088661194
[Train] epoch: 400/1000, loss: 0.31017178297042847
[Evaluate] best accuracy performence has been updated: 0.88750 --> 0.89375
[Train] epoch: 450/1000, loss: 0.2981906533241272
[Train] epoch: 500/1000, loss: 0.28831884264945984
[Train] epoch: 550/1000, loss: 0.28009912371635437
[Train] epoch: 600/1000, loss: 0.27321189641952515
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.87500 --> 0.88125
[Train] epoch: 250/1000, loss: 0.36777132749557495
[Train] epoch: 300/1000, loss: 0.3435978889465332
[Evaluate] best accuracy performence has been updated: 0.88125 --> 0.88750
[Train] epoch: 350/1000, loss: 0.32495132088661194
[Train] epoch: 400/1000, loss: 0.31017178297042847
[Evaluate] best accuracy performence has been updated: 0.88750 --> 0.89375
[Train] epoch: 450/1000, loss: 0.2981906533241272
[Train] epoch: 500/1000, loss: 0.28831884264945984
[Train] epoch: 550/1000, loss: 0.28009912371635437
[Train] epoch: 600/1000, loss: 0.27321189641952515
[Train] epoch: 650/1000, loss: 0.2674194276332855
[Train] epoch: 700/1000, loss: 0.26253607869148254
[Train] epoch: 750/1000, loss: 0.25841012597084045
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.50000
[Train] epoch: 0/1000, loss: 0.7069255709648132
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.50625
[Evaluate] best accuracy performence has been updated: 0.50625 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.55625
[Evaluate] best accuracy performence has been updated: 0.55625 --> 0.58750
[Evaluate] best accuracy performence has been updated: 0.58750 --> 0.60000
[Evaluate] best accuracy performence has been updated: 0.60000 --> 0.61875
[Evaluate] best accuracy performence has been updated: 0.61875 --> 0.62500
[Evaluate] best accuracy performence has been updated: 0.62500 --> 0.64375
[Evaluate] best accuracy performence has been updated: 0.64375 --> 0.67500
[Evaluate] best accuracy performence has been updated: 0.67500 --> 0.70000
[Evaluate] best accuracy performence has been updated: 0.70000 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.74375
[Evaluate] best accuracy performence has been updated: 0.74375 --> 0.75625
[Evaluate] best accuracy performence has been updated: 0.75625 --> 0.76875
[Evaluate] best accuracy performence has been updated: 0.76875 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.83125
[Evaluate] best accuracy performence has been updated: 0.83125 --> 0.84375
[Evaluate] best accuracy performence has been updated: 0.84375 --> 0.85000
[Evaluate] best accuracy performence has been updated: 0.85000 --> 0.85625
[Evaluate] best accuracy performence has been updated: 0.85625 --> 0.86250
[Train] epoch: 50/1000, loss: 0.582136332988739
[Evaluate] best accuracy performence has been updated: 0.86250 --> 0.86875
[Evaluate] best accuracy performence has been updated: 0.86875 --> 0.87500
[Train] epoch: 100/1000, loss: 0.5028225779533386
[Train] epoch: 150/1000, loss: 0.4436149597167969
[Train] epoch: 200/1000, loss: 0.399962842464447
[Evaluate] best accuracy performence has been updated: 0.87500 --> 0.88125
[Train] epoch: 250/1000, loss: 0.36777132749557495
[Train] epoch: 300/1000, loss: 0.3435978889465332
[Evaluate] best accuracy performence has been updated: 0.88125 --> 0.88750
[Train] epoch: 350/1000, loss: 0.32495132088661194
[Train] epoch: 400/1000, loss: 0.31017178297042847
[Evaluate] best accuracy performence has been updated: 0.88750 --> 0.89375
[Train] epoch: 450/1000, loss: 0.2981906533241272
[Train] epoch: 500/1000, loss: 0.28831884264945984
[Train] epoch: 550/1000, loss: 0.28009912371635437
[Train] epoch: 600/1000, loss: 0.27321189641952515
[Train] epoch: 650/1000, loss: 0.2674194276332855
[Train] epoch: 700/1000, loss: 0.26253607869148254
[Train] epoch: 750/1000, loss: 0.25841012597084045
[Train] epoch: 800/1000, loss: 0.25491634011268616
[Train] epoch: 850/1000, loss: 0.2519505023956299
[Train] epoch: 900/1000, loss: 0.24942593276500702
[Train] epoch: 950/1000, loss: 0.24727044999599457
可视化观察训练集与验证集的损失函数变化情况。
import matplotlib.pyplot as plt
# Draw the training and validation loss curves on a single figure, one point
# per epoch (x axis is 0 .. epoch_num - 1), then save it to PDF and display it.
plt.figure()
epochs = range(epoch_num)
# Each entry: (loss values recorded by the runner, plot styling kwargs).
loss_curves = [
    (runner.train_loss, dict(color="#8E004D", label="Train loss")),
    (runner.dev_loss, dict(color="#E20079", linestyle='--', label="Dev loss")),
]
for loss_values, plot_kwargs in loss_curves:
    plt.plot(epochs, loss_values, **plot_kwargs)
plt.xlabel("epoch", fontsize='x-large')
plt.ylabel("loss", fontsize='x-large')
plt.legend(fontsize='large')
# Save before show(): show() may close the figure, leaving nothing to save.
plt.savefig('fw-loss2.pdf')
plt.show()
7、性能评价
# Restore the trained model parameters stored under model_saved_dir
# (presumably the checkpoint written by the training loop — confirm in runner.train).
runner.load_model(model_saved_dir)
# Score the restored model on the held-out test split; evaluate() yields (score, loss).
score, loss = runner.evaluate([X_test, y_test])
test_report = "[Test] score/loss: {:.4f}/{:.4f}".format(score, loss)
print(test_report)
import math
# Build a uniform 200 x 200 grid (40000 points) covering [-pi, pi] x [-pi, pi].
x1, x2 = torch.meshgrid(torch.linspace(-math.pi, math.pi, 200), torch.linspace(-math.pi, math.pi, 200), indexing='ij')
# Flatten the grid into an (N, 2) tensor of (x1, x2) feature pairs.
x = torch.stack([torch.flatten(x1), torch.flatten(x2)], 1)
# Run the trained model on every grid point.
# NOTE(review): assumes runner.predict returns a torch tensor of sigmoid
# probabilities shaped [N, 1] (as Model_MLP_L2's final Logistic layer does) — confirm.
y = runner.predict(x)
# Threshold at 0.5 to turn probabilities into hard 0/1 class labels, so the
# background shows crisp class regions instead of a probability gradient.
# squeeze(-1) drops the trailing size-1 output dimension -> shape [N].
y_cls = torch.squeeze((y >= 0.5).to(torch.float32), -1)
# Draw the predicted class regions, then overlay the train/dev/test samples
# coloured by their true labels.
plt.ylabel('x2')
plt.xlabel('x1')
plt.scatter(x[:, 0].tolist(), x[:, 1].tolist(), c=y_cls.tolist(), cmap=plt.cm.Spectral)
plt.scatter(X_train[:, 0].tolist(), X_train[:, 1].tolist(), marker='*', c=torch.squeeze(y_train, -1).tolist())
plt.scatter(X_dev[:, 0].tolist(), X_dev[:, 1].tolist(), marker='*', c=torch.squeeze(y_dev, -1).tolist())
plt.scatter(X_test[:, 0].tolist(), X_test[:, 1].tolist(), marker='*', c=torch.squeeze(y_test, -1).tolist())
plt.show()