pytorch-Dataset自定义

处理自定义的数据集

1、pytorch 官方模板

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

2、自定义数据集(以Kaggle-ClassifyLeaves为例)

# 定义dataset
class LeavesData(Dataset):
    def __init__(self , csv_path , img_path , mode = 'train' , valid_ratio = 0.2 , resize_height = 256 , resize_weirgt = 256):
        '''
        Args:
            csv_path(string): csv文件路径 
            img_path(string): 图像文件夹所在路径
            mode(string): 训练模式,测试模式
            valid_ratio(float) : 验证集比例
        '''

        #对原图片尺寸进行统一
        self.resize_height = resize_height
        self.resize_weirgt = resize_weirgt

        self.file_path = img_path
        self.mode = mode

        #读取csv文件
        self.data_info = pd.read_csv(csv_path,header=None) # 不要读取表头
        #计算len
        self.data_len = len(self.data_info.index)-1
        #train_len
        self.train_len = int(self.data_len*(1-valid_ratio))

        if mode == 'train':
            #train图像的名称,label
            self.train_image = np.asarray(self.data_info.iloc[1:self.train_len,0])
            self.train_label = np.asarray(self.data_info.iloc[1:self.train_len,1])
            self.image_arr = self.train_image
            self.labe_arr = self.train_label

        elif mode == 'valid':
            self.valid_image = np.asarray(self.data_info.iloc[self.train_len:,0])
            self.valid_label = np.asarray(self.data_info.iloc[self.train_len:,1])
            self.image_arr = self.valid_image
            self.labe_arr = self.valid_label

        elif mode == 'test':
            self.test_image = np.asarray(self.data_info.iloc[1:,0])
            self.image_arr = self.test_image

        self.real_len = len(self.image_arr)

        print('Finished reading the {} set of Leaves Dataset {} samples found'.format(mode,self.real_len))
        
    def __getitem__(self,index):
        single_image_name  = self.image_arr[index]
        
        # 读到的是文件名,要读取图像文件
        img_as_img = Image.open(self.file_path + single_image_name)
        
        if self.mode == 'train':
            transform = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.ToTensor()
            ])
        
        else:
            transform = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor()
            ])

        img_as_img = transform(img_as_img)
        
        if self.mode == 'test':
            return img_as_img
        
        else:
            #返回图像label
            label = self.labe_arr[index]
            number_label = class_to_num[label]
        
            return img_as_img , number_label
    
    def __len__(self):
        return self.real_len    
       
train_path = r'E:/Learn_MechineLearning_DeepLearning/pytorch-Learning/Kaggle/Classify-leaves/classify-leaves/train.csv'
test_path = r'E:/Learn_MechineLearning_DeepLearning/pytorch-Learning/Kaggle/Classify-leaves/classify-leaves/test.csv'

img_path = r'E:/Learn_MechineLearning_DeepLearning/pytorch-Learning/Kaggle/Classify-leaves/classify-leaves/'

# 定义dataloader
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset,
    batch_size = 16,
    shuffle = True
)

val_loader = torch.utils.data.DataLoader(
    dataset = val_dataset,
    batch_size = 16,
    shuffle = True
)

test_loader = torch.utils.data.DataLoader(
    dataset = test_dataset,
    batch_size = 16,
    shuffle = True
)

# 显示一下前训练集前10张的图片
cnt = 0
for img , label in train_loader:
    img = img[0]   # 若出错可改成img = image[0].cuda 试试
    img = img.detach().numpy() # FloatTensor转为ndarray
    img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
    plt.imshow(img)
    plt.show()
    cnt = cnt+1
    if cnt == 10:
        break



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值