Deep Learning -- Early Stopping

When training a network, it is often hard to decide in advance how many epochs to run. Adding an early-stopping strategy handles this: training halts automatically once the loss stops decreasing. Hopefully this helps; the code is below:

import numpy as np
import torch
import os


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, save_path, patience=30, verbose=False, delta=0):
        """
        Args:
            save_path : directory in which the best model checkpoint is saved
            patience (int): How long to wait after last time validation loss improved.
                            Default: 30 — training stops after the loss has failed
                            to improve for `patience` consecutive calls
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.save_path = save_path
        os.makedirs(save_path, exist_ok=True)  # make sure the checkpoint directory exists
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        path = os.path.join(self.save_path, 'best_network.pth')
        torch.save(model.state_dict(), path)  # save the best model parameters
        self.val_loss_min = val_loss

Put the code above in a separate .py file (here early_stopping.py, matching the import below) and call it during training:

from tqdm import tqdm

from early_stopping import EarlyStopping

save_path = "./train-pth/"  # directory for the saved model parameters
early_stopping = EarlyStopping(save_path)
for epoch in tqdm(range(epoch_number)):
    # ... training and validation for one epoch omitted ...
    early_stopping(val_loss, train_model)  # call at the end of every epoch; pass the validation loss, not the training loss
    if early_stopping.early_stop:
        print("Early stopping")
        break  # leave the training loop
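To see the stopper fire without a real training run, here is a minimal sanity check on synthetic losses. The class is inlined (condensed, without `verbose` printing) so the snippet runs standalone; `patience=3`, the `nn.Linear` stand-in model, and the loss values are made up for illustration:

```python
import os
import tempfile

import numpy as np
import torch


class EarlyStopping:
    """Condensed copy of the class above, for a self-contained demo."""

    def __init__(self, save_path, patience=30, delta=0):
        self.save_path = save_path
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        torch.save(model.state_dict(),
                   os.path.join(self.save_path, 'best_network.pth'))
        self.val_loss_min = val_loss


save_dir = tempfile.mkdtemp()               # throwaway checkpoint directory
model = torch.nn.Linear(2, 1)               # stand-in for a real network
stopper = EarlyStopping(save_dir, patience=3)

# Loss improves for three epochs, then gets worse: the stop fires on the
# third consecutive non-improving epoch (index 5); epoch 6 is never reached.
losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.5]
stopped_at = None
for epoch, loss in enumerate(losses):
    stopper(loss, model)
    if stopper.early_stop:
        stopped_at = epoch
        break

print(stopped_at)                           # -> 5
print(stopper.val_loss_min)                 # -> 0.7, the best loss seen
```

Note that with `delta=0` a loss exactly equal to the best so far still counts as an improvement (the comparison is strict); set a small positive `delta` to require a meaningful decrease.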