Using Cross-Validation in Deep Learning (2)

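My previous projects had plenty of data, so I always split it 7:1:2 into training, validation, and test sets. I used the validation set to select the best model and then evaluated that model on the test set.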
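The 7:1:2 split amounts to shuffling the sample indices and cutting them into three pieces. The sketch below does this with NumPy. The function name `split_712`, the seed, and the 150-sample size are hypothetical, chosen only for illustration.

```python
import numpy as np

def split_712(n_samples, seed=42):
    """Shuffle indices 0..n_samples-1 and split them 7:1:2 into train/val/test."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    # Integer arithmetic avoids floating-point rounding surprises in the split sizes
    n_train = n_samples * 7 // 10
    n_val = n_samples * 1 // 10
    train_idx = idx[:n_train]
    val_idx = idx[n_train:n_train + n_val]
    # Whatever remains (about 20%) becomes the test set
    test_idx = idx[n_train + n_val:]
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = split_712(150)
print(len(train_idx), len(val_idx), len(test_idx))  # 105 15 30
```

With 150 samples the validation set gets only 15 samples and the test set only 30. This is the small-data situation the next paragraph describes.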

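This project, however, has much less data. The validation and test sets would hold only a dozen or twenty-odd samples each, which is too few to validate or test a model meaningfully.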
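A rough back-of-the-envelope check shows why. Suppose each evaluation sample is an independent trial (an assumption) and the true accuracy is a hypothetical 0.8. The binomial standard error of the measured accuracy is then sqrt(p * (1 - p) / n), and a single sample flipping between right and wrong moves the accuracy by 1/n. The helper name `accuracy_std_error` and the sample sizes are illustrative.

```python
import math

def accuracy_std_error(acc, n):
    """Binomial standard error of an accuracy estimate measured on n samples."""
    return math.sqrt(acc * (1 - acc) / n)

for n in (15, 25, 200):
    se = accuracy_std_error(0.8, n)
    # 1 / n is how much accuracy changes when one prediction flips
    print(f"n={n:3d}  step per sample={1 / n:.3f}  std error={se:.3f}")
```

With 15 samples, one prediction moves accuracy by about 6.7 percentage points and the standard error is roughly 10 points. With 200 samples the standard error drops below 3 points.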

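So for training I decided to use cross-validation instead and split the data into only a training set and a test set. During training, cross-validation on the training set selects the best-performing model, which is then evaluated directly on the test set.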
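The following is a minimal sketch of this split. It assumes scikit-learn's `train_test_split` and `KFold`, a hypothetical dataset of 100 samples, and a 20% test set. The original does not state its test ratio.

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

n_samples = 100  # hypothetical dataset size
all_idx = np.arange(n_samples)

# Step 1: set aside the test set once; it is never touched during cross-validation
trainval_idx, test_idx = train_test_split(all_idx, test_size=0.2, random_state=42)

# Step 2: 5-fold cross-validation on the remaining samples only
kf = KFold(n_splits=5, shuffle=True, random_state=42)
folds = []
for fold, (tr, va) in enumerate(kf.split(trainval_idx)):
    # kf.split yields positions within trainval_idx, so map them back to dataset indices
    train_idx, val_idx = trainval_idx[tr], trainval_idx[va]
    folds.append((train_idx, val_idx))
    print(f"Fold {fold + 1}: {len(train_idx)} train, {len(val_idx)} val, {len(test_idx)} test")
```

`KFold.split` returns positions within the array it receives, not the original dataset indices. That is why the sketch maps them back through `trainval_idx`. If the dataset passed to `KFold` already contains only the training portion, as with `np.arange(len(self.CESMdata))` in the code below, the positions and indices coincide.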

After reading through some references, I eventually got cross-validation working in a deep learning setup.

The training code is as follows:

import os
import sys
import argparse
import time
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.nn import CosineEmbeddingLoss

from torch.utils.data import DataLoader, Subset
#from models.Vit import CustomViT as model
from models.resnet50 import CustomResNet50 as model
# from models.ConvNetT import CustomConvNeXtT as model
from conf import settings
from utils import (get_network, WarmUpLR, most_recent_folder,
                   most_recent_weights, last_epoch, best_acc_weights)
from dataset import CESM

# Folder in which the recorded data is saved
folder_path = r"C:\Users\Administrator\Desktop\Breast\临时"

my_train_loss = []
my_eval_loss = []
my_train_correct = []
my_eval_correct = []
class Trainer:
    def __init__(self, net, CESMdata, loss_function, optimizer, device, batch_size):
        self.net = net
        self.CESMdata = CESMdata
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.device = device
        self.batch_size = batch_size
        self.my_train_loss = []
        self.my_train_correct = []
        self.my_val_loss = []
        self.my_val_correct = []

    def train(self, epoch):
        start = time.time()
        self.my_train_loss = []
        self.my_train_correct = []
        self.my_val_loss = []
        self.my_val_correct = []

        self.net.train()
        kf = KFold(n_splits=5, shuffle=True, random_state=42)

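        # If the model is trained inside this loop, self.net and self.optimizer carry over
        # from one fold to the next, so later folds are validated on samples the network
        # has already trained on; re-creating both at the start of each fold avoids this.
        for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(self.CESMdata)))):
            print(f'\nFold {fold + 1}/{kf.n_splits}')

            # Create the training and validation subsets for this fold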
            train_subset = Subset(self.CESMdata, train_idx)
            val_subset = Subset(self.CESMdata, val_idx)

            train_load
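A per-fold loop along these lines can be sketched in PyTorch. This variant builds a fresh model and optimizer for every fold and keeps the weights of the fold with the highest validation accuracy. Everything here is an assumption for illustration. `make_model` is a hypothetical tiny MLP standing in for `CustomResNet50`. `run_fold` and `cross_validate` are hypothetical helpers. Adam and `CrossEntropyLoss` are placeholder choices. A synthetic `TensorDataset` replaces `CESM`.

```python
import copy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset
from sklearn.model_selection import KFold

def make_model():
    # Hypothetical stand-in for CustomResNet50: a tiny MLP on 8 features, 2 classes
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

def run_fold(model, train_loader, val_loader, epochs=5, lr=1e-2):
    loss_fn = nn.CrossEntropyLoss()
    # The optimizer is created per fold because it holds references to this fold's model
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
    # Evaluate once on the held-out fold
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in val_loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
    return correct / len(val_loader.dataset)

def cross_validate(dataset, n_splits=5, batch_size=8, seed=42):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    best_acc, best_state, fold_accs = -1.0, None, []
    for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(dataset)))):
        torch.manual_seed(seed + fold)
        model = make_model()  # fresh weights: no fold has seen its own validation samples
        train_loader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size)
        acc = run_fold(model, train_loader, val_loader)
        fold_accs.append(acc)
        print(f"Fold {fold + 1}/{n_splits}: val acc = {acc:.3f}")
        if acc > best_acc:
            best_acc, best_state = acc, copy.deepcopy(model.state_dict())
    return fold_accs, best_state

# Synthetic data standing in for the CESM dataset
torch.manual_seed(0)
X = torch.randn(60, 8)
y = (X[:, 0] > 0).long()
fold_accs, best_state = cross_validate(TensorDataset(X, y))
print(f"mean val acc = {np.mean(fold_accs):.3f}")
```

Creating the model inside the loop means each fold's validation accuracy comes from a network that never trained on those samples. To finish the procedure described at the start of the post, `best_state` would be loaded into a fresh `make_model()` and evaluated once on the held-out test set. The mean of `fold_accs` gives a steadier picture of performance than any single fold.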