Pytorch框架猫狗大战之数据预处理

本文详细解析了PyTorch中DogCat数据集的预处理流程,包括数据集的加载、图片读取、数据增强及转换操作。针对训练、验证和测试集的不同需求,介绍了图像的标准化、尺寸调整、随机裁剪和翻转等处理技巧。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

数据预处理阶段,dataset.py中主要定义了DogCat类,并定义了__init__、__getitem__、__len__三个类方法,话不多说直接上代码,具体解读都在每一行后面的注释中。

# coding:utf8
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T


class DogCat(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
        """
        #os.listdir() 方法 : 返回指定文件夹包含的文件或文件夹名字的列表。该列表顺序以字母排序。
        #os.path.join() : 将多个路径组合后返回
        # 即imgs中存储的是路径+文件名的列表,如下
        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('\\')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

        imgs_num = len(imgs)

        #划分训练、验证集,验证:训练 = 3:7
        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]

        if transforms is None:
            #对图片数据进行转换操作,测试、验证和训练的数据转换有所区别
            #标准化,即减均值,除以标准差,1*3对应RGB三通道,这组数据是imagenet训练集中抽样算出来的
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            #测试集和验证集
            if self.test or not train:
                self.transforms = T.Compose([
                    T.Resize(224),      #缩放Image,长宽比不变,最短边为224像素
                    T.CenterCrop(224),  #从图片中间切出224*224的图片
                    T.ToTensor(),       #将图片转成Tensor并归一化至[0,1]
                    normalize           #标准化
                ])
            #训练集
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),   #先随机采集,然后对裁剪得到的图像缩放为同一大小
                    T.RandomHorizontalFlip(),   #以给定的概率p随机水平旋转PIL的图像,p默认值为0.5
                    T.ToTensor(),
                    normalize
                ])

    def __getitem__(self, index):
        """
        一次返回一张图片的数据
        如果是测试集,没有图片id,如1000.jpg返回1000
        """
        img_path = self.imgs[index]     #第index张图片的路径+文件名
        if self.test:
            label = int(self.imgs[index].split('.')[-2].split('\\')[-1])
        else:
            label = 1 if 'dog' in img_path.split('\\')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)        #调用init中的transform处理图片
        return data, label

    def __len__(self):
        '''
        返回数据集中所有图片的个数
        '''
        return len(self.imgs)

附上我看代码时参考的别人的博客:

pytorch torchvision.transforms.Normalize()中的mean和std参数---解惑

transforms的二十二个方法

pytorch中T.RandomResizedCrop()等图像操作转换的效果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值