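Competition URL: http://challenge.xfyun.cn/topic/info?type=PET
See the URL above for the competition description and requirements. Downloading the data requires real-name verification, so a [Baidu Netdisk download link] is provided here (extraction code: 4hz2).
Competition data
The competition data is split into a training set and a test set.
- The training set contains images of two classes, CN and AD. Each class has its own folder with 1,000 images, so the training set has 2,000 images in total.
- The test set has 1,000 unlabeled images. We train a model to predict whether each one is AD or CN and submit the results as a CSV file. The competition site provides a submission template.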
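Because each class sits in its own folder, a label can be read from an image's parent directory name, and predictions can be written out with the standard `csv` module. Here is a minimal sketch under stated assumptions: the helper names `label_from_path` and `write_submission`, the example file name, and the column names `uuid` and `label` are all hypothetical. The submission template on the competition site defines the real format.

```python
import csv
import io
from pathlib import PureWindowsPath


def label_from_path(path):
    """Return the class name (e.g. 'AD' or 'CN') taken from the parent folder."""
    # PureWindowsPath accepts both '\\' and '/' as separators,
    # so this works for paths written in either style.
    return PureWindowsPath(path).parent.name


def write_submission(rows, fileobj, header=("uuid", "label")):
    """Write (image_id, predicted_label) pairs as CSV.

    The header names are placeholders (an assumption); replace them with
    the columns used in the official submission template.
    """
    writer = csv.writer(fileobj, lineterminator="\n")
    writer.writerow(header)
    writer.writerows(rows)


if __name__ == "__main__":
    # Hypothetical file name, for illustration only.
    print(label_from_path(r"data\train\AD\1.png"))

    buf = io.StringIO()
    write_submission([(1, "AD"), (2, "CN")], buf)
    print(buf.getvalue())
```

To write a real file instead of an in-memory buffer, pass a handle opened with `open("submit.csv", "w", newline="")`; the file name is likewise only an example.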
The training images vary in size. To see exactly which sizes occur, iterate over them with the following code:
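import glob

import cv2

# Forward slashes work on Windows, Linux, and macOS alike.
train_image_path = glob.glob("data/train/*/*.png")

# Collect every distinct (height, width) pair found in the training set.
sizeSet = set()
for i in train_image_path:
    img = cv2.imread(i)
    if img is None:  # skip files that OpenCV cannot decode
        continue
    sizeSet.add(img.shape[:2])
print(sizeSet)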
Output:
{(256, 256), (168, 168), (336, 336), (128, 128)}
Note, therefore, that the competition images come in four sizes: 128×128, 168×168, 256×256, and 336×336.
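A set only shows which sizes occur. Counting them also shows how the images are distributed across those sizes, which can help when choosing a common input resolution. Here is a minimal sketch, assuming the shapes have already been collected into a list. `count_sizes` is a hypothetical helper, and the sample list is made up for illustration, not measured from the competition data.

```python
from collections import Counter


def count_sizes(shapes):
    """Count how often each (height, width) pair occurs."""
    return Counter((int(h), int(w)) for h, w in shapes)


if __name__ == "__main__":
    # Made-up shapes for illustration only.
    sample = [(128, 128), (256, 256), (128, 128), (336, 336)]
    for size, n in count_sizes(sample).most_common():
        print(size, n)
```

With the real data, the list passed in would hold `img.shape[:2]` for each training image that loads successfully.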
0. Import the required libraries
# -*- coding: utf-8 -*-
import os, sys, glob, argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
import time, datetime
import pdb, traceback
import cv2
# import imagehash
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
# from efficientnet_pytorch import EfficientNet
# model = EfficientNet.from_pretrained('efficientnet-b4')
import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
# Print the version of each core library for reproducibility.
for i in (pd, np, torch):
    print("{}: {}".format(i.__name__, i.__version__))
Output:
pandas: 0.24.2
numpy: 1.17.4
torch: 1.5.1+cu101
1. Training code:
# Input dataset: paths of all training images.
# Forward slashes work on Windows, Linux, and macOS alike.
train_jpg = np.array(glob.glob('data/train/*/*.png'))

class QRDataset(Dataset):
    def __init__(self, train_jpg, transform=None):
        self.train_jpg = train_jpg
        # Optional torchvision-style transform; None means no preprocessing.
        self.transform = transform

    def __getitem__(self, index):
        start_time = time.time()
        # Load the image and force three channels.
        img = Image.open(self.train_jpg[index]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)