import numpy as np
import torch
import os

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, save_path, patience=30, verbose=False, delta=0):
        """
        Args:
            save_path : directory in which the best model checkpoint is saved
            patience (int): How long to wait after last time validation loss improved.
                            Default: 30 (training stops once the loss has failed to
                            improve for 30 consecutive epochs)
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            # First call: record the score and save an initial checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # No sufficient improvement: increment the counter and stop once patience is exhausted.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: save the new best checkpoint and reset the counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves the model when the validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        path = os.path.join(self.save_path, 'best_network.pth')
        torch.save(model.state_dict(), path)  # save the parameters of the best model so far
        self.val_loss_min = val_loss
We put the code above into a separate .py file (early_stopping.py) and then call it during training:
import os
from tqdm import tqdm
from early_stopping import EarlyStopping

save_path = "./train-pth/"  # directory in which the model parameters are saved
os.makedirs(save_path, exist_ok=True)  # the directory must exist before torch.save is called
early_stopping = EarlyStopping(save_path)

for epoch in tqdm(range(epoch_number)):
    ...  # the concrete training steps for this epoch are omitted
    early_stopping(train_loss, train_model)  # call once at the end of every epoch
    if early_stopping.early_stop:
        print("Early stopping")
        break  # leave the training loop
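
After training stops, the parameters written to best_network.pth can be loaded back for evaluation. Below is a minimal sketch, assuming the same save_path as above and a placeholder model class Net standing in for whatever network was actually trained:

import os
import torch

model = Net()  # placeholder: construct the same architecture that was trained
state_dict = torch.load(os.path.join(save_path, 'best_network.pth'))
model.load_state_dict(state_dict)  # restore the best checkpoint saved by EarlyStopping
model.eval()  # switch to evaluation mode before validating or testing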