def save_baet_model(test_nat_correct, test_adv_correct,
best_nat_correct, best_adv_correct,
best_nat_model, best_adv_model,
best_nat_epoch, best_adv_epoch,
model, epoch):
if test_nat_correct > best_nat_correct:
best_nat_correct = test_nat_correct
best_nat_model = model.module.state_dict()
best_nat_epoch = epoch
if test_adv_correct > best_adv_correct:
best_adv_correct = test_adv_correct
best_adv_model = model.module.state_dict()
best_adv_epoch = epoch
return best_nat_correct, best_adv_correct, best_nat_epoch, best_adv_epoch, best_nat_model, best_adv_model
运行后随着epoch的增加,model的参数也会更新,相应的影响了best_adv_model、best_nat_model,因为model不是局部变量,没有随着子函数的结束而取消内存分配,重分配地址,导致变量一直指向这个变化的变量(python基础都忘了唉)
该代码段实现了一个功能,用于在训练过程中保存具有最佳性能的模型。它比较当前测试集上的自然语言任务(nat)和对抗性任务(adv)的正确率,并更新最佳模型的状态字典及对应的epoch。由于model是全局变量,其参数在epoch间持续更新,导致best_adv_model和best_nat_model也随model参数变化,而非保持最优状态。
1983

被折叠的 条评论
为什么被折叠?



