数据集读取方式
图片已按类别分好文件夹时，可直接用 ImageFolder 读取：
# Build the train/test datasets from class-per-folder directory trees.
# NOTE(review): the root paths are '' placeholders — fill in the real
# train/test directories before running.
dataset_train = datasets.ImageFolder(root='', transform=transform_train)
dataset_test = datasets.ImageFolder(root='', transform=transform_test)

# Shuffle only the training split; keep the test split in a fixed order
# so evaluation results are reproducible batch-by-batch.
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
)
类别信息包含在文件名中时：
需先根据文件名将图片分类，再调用
import os, glob
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset
class data_set(Dataset):
def __init__(self, folder, transform=None, train=