# Cross-validation: run trainval_test once for each fold index below.
# Fold indices are kept as strings because they are concatenated directly
# into the log message (and into per-fold file names).
cross_val_lists = ['0', '1', '2', '3', '4']
for cross_val_index in cross_val_lists:
    log.write('\n\ncross_val_index: ' + cross_val_index + '\n\n')
    # sigma and lam are hyperparameters forwarded to trainval_test (their
    # meaning is defined there -- presumably a label-distribution width and a
    # loss weight; TODO confirm). Note that in floating point these
    # expressions evaluate to 3.0000000000000004 and 0.6000000000000001.
    trainval_test(cross_val_index, sigma=30 * 0.1, lam=6 * 0.1)
# Text files listing the images and their labels for this fold's
# train/val split and test split.
TRAIN_FILE = './Classification/NNEW_trainval_' + cross_val_index + '.txt'
TEST_FILE = './Classification/NNEW_test_' + cross_val_index + '.txt'

# Build the datasets. Per the original annotation, each sample is
# (img, label, lesion): the image, its label and its acne lesion count.
# Per-channel normalisation statistics (presumably computed on the training
# images -- TODO confirm).
normalize = transforms.Normalize(mean=[0.45815152, 0.361242, 0.29348266],
                                 std=[0.2814769, 0.226306, 0.20132513])
# transforms.Resize replaces transforms.Scale, which is deprecated and has
# been removed from newer torchvision releases; given an (h, w) tuple both
# resize the image to exactly that size.
dset_train = dataset_processing.DatasetProcessing(
    DATA_PATH, TRAIN_FILE, transform=transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Project-defined transform applied after ToTensor, so it presumably
        # operates on tensors; rotation_range semantics are defined by
        # RandomRotate -- TODO confirm.
        RandomRotate(rotation_range=20),
        normalize,
    ]))  # return img, label, lesion
# NOTE(review): training images are resized to 256x256 and randomly cropped
# to 224, whereas test images are resized straight to 224x224, so objects
# appear at a different scale at test time. Confirm this is intended.
dset_test = dataset_processing.DatasetProcessing(
    DATA_PATH, TEST_FILE, transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]))

# Wrap the datasets in batched, iterable loaders; only the training set is
# shuffled.
train_loader = DataLoader(dset_train,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKERS,
                          pin_memory=False)
test_loader = DataLoader(dset_test,
                         batch_size=BATCH_SIZE_TEST,
                         shuffle=False,
                         num_workers=NUM_WORKERS,
                         pin_memory=False)
# Train for 120 epochs; in each epoch, iterate over every batch.