基于PyTorch的卷积神经网络图像分类——猫狗大战(一):使用Pytorch定义DataLoader

目录

1. 需要用到的库

2. 数据扩充定义

3. 自定义Dataset

4. 测试


         开始一个新的系列,基于Kaggle比赛的猫狗大战数据集,基于PyTorch实现猫狗图像分类。

         如何定义网络模型见:https://blog.youkuaiyun.com/linghu8812/article/details/119147899

         数据集地址在:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/overview

         下面是第一部分,主要介绍如何使用Pytorch自定义Dataloader。

1. 需要用到的库

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

2. 数据扩充定义

image_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

数据扩充主要分为以下几步:

1)将图像的短边resize到256;

2)然后随即裁减224x224;

3)再进行随机水平翻转;

4)最后将图像转为Tensor并且标准化。

3. 自定义Dataset

class DogVsCatDataset(Dataset):
    """Dog vs Cat dataset."""

    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.img_path = os.listdir(self.root_dir)
        if train:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))
        else:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
        self.transform = transform

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
        label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array([label]))
        return image, label

数据集初始化时要设置图片目录;是否是训练集或者是验证集,图片编号小于10000的为训练集,大于等于10000的为验证集;及数据扩充方式;猫的标签为0,狗的标签为1。

4. 测试

if __name__ == '__main__':
    catanddog_dataset = DogVsCatDataset(root_dir='../dogs-vs-cats-redux-kernels-edition/train', train=False,
                                        transform=image_transform)
    train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)
    image, label = iter(train_loader).next()
    sample = image[0].squeeze()
    sample = sample.permute((1, 2, 0)).numpy()
    sample *= [0.229, 0.224, 0.225]
    sample += [0.485, 0.456, 0.406]
    plt.imshow(sample)
    plt.show()
    print('Label is: {}'.format(label[0].numpy()))

测试的时候使用“if __name__ == '__main__':”可以在其他文件import时,不执行这些语句。执行代码后,显示的图片和打印的标签如下所示:

Label is: [0]

Label is: [1]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值