在之前的项目中,数据量比较大,通常按照7:1:2的比例将数据划分为训练集、验证集和测试集:用验证集选出最优的模型,然后在测试集上进行测试。
但是这次项目的数据比较少,按这种方式划分后,验证集和测试集只有十几到二十几个样本,用它们来验证和测试模型意义不大。
所以,这次训练时改用交叉验证的方法:数据只划分为训练集和测试集,训练时在训练集上进行5折交叉验证,选出结果最好的模型,再直接在测试集上进行测试。
查阅了一些资料后,最终在基于PyTorch的深度学习训练流程中实现了交叉验证。
训练部分的代码如下:
import os
import sys
import argparse
import time
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.nn import CosineEmbeddingLoss
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
#from models.Vit import CustomViT as model
from models.resnet50 import CustomResNet50 as model
# from models.ConvNetT import CustomConvNeXtT as model
from conf import settings
from utils import get_network, WarmUpLR, \
most_recent_folder, most_recent_weights, last_epoch, best_acc_weights
from dataset import CESM
# Folder in which data is saved (hard-coded absolute Windows path;
# adjust per machine).
folder_path = r"C:\Users\Administrator\Desktop\Breast\临时"

# Module-level metric histories, each starting as its own empty list.
my_train_loss, my_eval_loss, my_train_correct, my_eval_correct = [], [], [], []
import time
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold
# NOTE: my_train_loss, my_eval_loss, my_train_correct and my_eval_correct are
# initialised once, earlier in this module, and nothing modifies them before
# this point; do not re-initialise them here (doing so was redundant).
class Trainer:
def __init__(self, net, CESMdata, loss_function, optimizer, device, batch_size):
self.net = net
self.CESMdata = CESMdata
self.loss_function = loss_function
self.optimizer = optimizer
self.device = device
self.batch_size = batch_size
self.my_train_loss = []
self.my_train_correct = []
self.my_val_loss = []
self.my_val_correct = []
def train(self, epoch):
start = time.time()
self.my_train_loss = []
self.my_train_correct = []
self.my_val_loss = []
self.my_val_correct = []
self.net.train()
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(self.CESMdata)))):
print(f'\nFold {fold + 1}/{kf.n_splits}')
# 创建训练集和验证集
train_subset = Subset(self.CESMdata, train_idx)
val_subset = Subset(self.CESMdata, val_idx)
train_load