Pytorch基于卷积神经网络的猫狗识别

实验环境

  1. Pytorch 1.4.0
  2. conda 4.7.12
  3. Jupyter Notebook 6.0.1
  4. Python 3.7

数据集介绍

实验采用的猫和狗的图片来自 Kaggle 竞赛的一个赛题 Cat vs Dog 的数据集,其中训练数据集包 括 25000 张图片,其中类别为猫的图片有 12500 张图片,类别为狗的图片有 12500 张,两种类别比例为 1:1。训练集有 25000 张,猫狗各占一半。测试集 12500 张,猫狗各占一半。实际上该数据集是 Asirra 数据集的子集。

Asirra 数据集的来源:
Web 服务有时通过行为验证信息来保护自身不被网络攻击,因为类似识别一个物品这样的问题对 人们来说很容易解决,但对计算机却很难。这种挑战通常称为 CAPTCHA 完全自动化的公共 Turing 测试,以区分计算机和人类)或 HIP(人类互动证明)。HIP 有多种用途,例如减少电子邮件和博客垃 圾邮件,以及防止对网站密码的暴力攻击。Asirra(用于限制访问的动物物种图像识别)是一种 HIP, 其工作原理是要求用户识别猫和狗的照片。对于计算机而言,此任务很困难,但研究表明,人们可以快 速而准确地完成此任务。Asirra 之所以与众不同,是因为它与 Petfinder.com 合作,Petfinder.com 是全 球最大的致力于寻找无家可归宠物的家的网站。他们为 Microsoft Research 提供了超过三百万张猫和 狗的图像,这些图像由美国数千家动物收容所中的人手动分类。Kaggle 很幸运能够提供这些数据的子 集,以供娱乐和研究之用。

训练过程

数据准备

数据预处理:首先,导入实验所需的库,定义一些宏参数,BATCH_SIZE 表示每个 batch 加载多 少个样本、EPOCHS 表示总共训练批次。如果支持 cuda 就用 gpu 来 run,不支持就用 cpu 来 run。

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms,models
from torch.optim.lr_scheduler import *
import copy
import random
import tqdm
from PIL import Image
import torch.nn.functional as F

%matplotlib inline

BATCH_SIZE = 20
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

从 Kaggle 官网下载好数据集 train.zip 和 test1.zip,解压到项目目录 data 文件夹下,重命名训练 集和测试集文件夹名字。由于 listdir 参数不允许有”…” 和”.”,所以我先获取项目路径,再拼接上项目目 录下训练集和测试集的位置,构成训练集和测试集的路径地址,最后通过 listdir 获取相应目录下文件 名的集合。

cPath = os.getcwd()
train_dir = cPath + '/data/train'
test_dir = cPath + '/data/test'
train_files = os.listdir(train_dir)
test_files = os.listdir(test_dir)

训练集的图片命名规则是:类型. 序号.jpg,我定义一个数据集处理类 CatDogDataset 来对数据集 进行预处理,狗的 label 为 1,猫的 label 为 0。以及在 getitem 时调用 transform 处理输入数据,根据 mode 返回不同的信息,mode=train 则返回训练图片和标签,其他则返回图片和图片文件名。

class CatDogDataset(Dataset):
    def __init__(self, file_list, dir, mode='train', transform = None):
        self.file_list = file_list
        self.dir = dir
        self.mode= mode
        self.transform = transform
        if self.mode == 'train':
            if 'dog' in self.file_list[0]:
                self.label = 1
            else:
                self.label = 0
            
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.dir, self.file_list[idx]))
        if self.transform:
            img = self.transform(img)
        if self.mode == 'train':
            img = img.numpy()
            return img.astype('float32'), self.label
        else:
            img = img.numpy()
            return img.astype('float32'), self.file_list[idx]

使用自定义的 transform 进行数据增强,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛 化能力,以及数据处理统一输入图片格式大小和归一化。train_transforms 先调整图片大小至 256x256 重置图像分辨率,再按照 224x224 随机剪裁,然后随机的图像水平翻转,转化成 tensor,最后采用 ImageNet 给出的数值归一化。接着构造 train dataloader,目的是为了方便读取和使用,设置 batch 大 小,采用多线程,shuffle=True 设置在每个 epoch 重新打乱数据,保证数据的随机性。
test_transform 重置图片分辨率 224x224,转化成 tensor,同样采用 ImageNet 给出的数值归一化。 接着构造 test dataloader,设置 batch size,采用多线程,shuffle=False。

train_transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 先调整图片大小至256x256
    transforms.RandomCrop((224, 224)),  # 再随机裁剪到224x224
    transforms.RandomHorizontalFlip(),  # 随机的图像水平翻转,通俗讲就是图像的左右对调
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # 归一化,数值是用ImageNet给出的数值
])


cat_files = [tf for tf in train_files if 'cat' in tf]
dog_files = [tf for tf in train_files if 'dog' in tf]

cats = CatDogDataset(cat_files, train_dir, transform = train_transform)
dogs = CatDogDataset(dog_files, train_dir, transform = train_transform)

train_set = ConcatDataset([cats, dogs])
train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle=True, num_workers=0)

test_transform = transforms.Compose([
    transforms.
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值