图像二分类数据集划分和预处理以及批处理pytorch代码实现

本文介绍了如何在PyTorch中处理图像二分类问题,包括理解epoch、iteration和batch_size的关系,数据集的划分,自定义Dataset类,以及使用transforms进行图像增强和Dataloader进行批处理。示例中展示了未处理数据的存储结构,并提供了划分训练集、验证集和测试集的代码,以及实现Dataset类的关键方法。

在进行图像的数据处理之前我们先搞懂epoch、iteration和batch_size三者的关系

epoch:所有训练样本都送入到模型中称为一个eopch
iteration:一批样本输入到模型中,称之为一个iteration
batch_size:批大小,决定了一个epoch有多少个iteration
假设样本总数80 batch_size:8
则1epoch = 10iteration

下面先简单讲解一下思路
1.首先将图片进行训练集、验证集和测试集划分
2.接下来因为是使用pytorch的Dataloader类进行数据处理,因此我们要实现Dataset类,这个类需要自己实现
3.写完Dataset类可以用它来构建Dataloader类,用Dataloader类进行数据批处理,在进行数据批处理之前可以使用transforms对图像进行图像增强

首先先介绍未处理数据的文件存储结构.hello pytorch为项目的根路径
图像标签为1的存储路径:hello pytorch/data/RMB_data/1 #文件夹1下存放的是标签为1的图片 数量100张
图像标签为100的存储路径:hello pytorch/data/RMB_data/100 # 文件夹100下存放的是标签为100的图片 数量100张

在这里插入图片描述

划分训练集、验证集和测试集代码

import os
import random
import shutil

def makedir(new_dir): # 如果new_dir文件夹不存在则创造此文件夹
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

dataset_dir = os.path.join("..", "..", "data", "RMB_data")  # F:\pythonProject\hello pytorch\data\RMB_data
split_dir = os.path.join("..", "..", "data", "rmb_split") # 为划分的数据集准备的文件夹,此时还没有创建
train_dir = os.path.join(split_dir, "train") # 为训练集准备的文件夹,此时还没有创建
valid_dir = os.path.join(split_dir, "valid") # 为验证集创建的文件夹,此时还没有创建
test_dir = os.path.join(split_dir, "test") # 为测试集创造的文件夹,此时还没有创建

# 为测试集、训练集和验证集划分比例
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1

for root, dirs, files in os.walk(dataset_dir):

        for sub_dir in dirs:  # 1
            imgs = os.listdir(os.path.join(root, sub_dir)
<think>好的,用户想了解在PyTorch中如何进行数据集划分数据预处理。首先,我需要回忆相关的知识,并结合提供的引用资料来组织回答。 首先,数据集划分通常涉及将数据分为训练集、验证集测试集。PyTorch中常用的方法是使用`random_split`函数。我需要确认这个函数的具体用法,比如是否需要先创建一个数据集对象,然后按比例划分。记得引用中提到的`Dataset``DataLoader`类,这些是关键点。 接下来是数据预处理部分。根据引用[2][4],PyTorch主要通过`torchvision.transforms`模块来处理数据,比如标准化、数据增强等。需要提到常见的预处理操作,如调整大小、转换为张量、归一化等。可能还需要提到自定义数据集的方法,如继承`Dataset`类并实现`__getitem__``__len__`方法。 然后,数据加载部分,使用`DataLoader`来批量加载数据,设置参数如`batch_size`、`shuffle`等。这里需要强调`DataLoader`的作用,如包装数据集为可迭代对象,方便批处理。 还需要考虑用户可能的实际应用场景,比如图像分类任务。举个例子,使用CIFAR10数据集,应用数据增强方法,如随机裁剪、水平翻转等。同时,验证集测试集可能不需要数据增强,只做基础预处理。 另外,用户可能关心如何确保划分后的数据集没有重叠,特别是当数据集本身没有预分割的情况下。这时候`random_split`的参数设置很重要,如指定长度生成器种子以确保可重复性。 可能还需要提到自定义数据集的步骤,比如处理不同的数据格式或结构,如何结合`transforms`进行处理。引用[3]提到了数据加载工具类,所以需要强调`DataLoader`的灵活性效率。 最后,总结步骤,确保回答结构清晰:数据集划分预处理→加载。同时,生成相关问题,帮助用户进一步探索,比如性能优化、高级预处理技巧、自定义数据集方法等。</think>在PyTorch中,数据集划分与数据预处理是模型训练的重要前置步骤,具体实现方法如下: ### 一、数据集划分方法 PyTorch提供了`torch.utils.data.random_split`函数实现随机划分: ```python from torch.utils.data import Dataset, random_split # 假设已定义完整数据集full_dataset train_size = int(0.8 * len(full_dataset)) val_size = int(0.1 * len(full_dataset)) test_size = len(full_dataset) - train_size - val_size train_dataset, val_dataset, test_dataset = random_split( full_dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42) # 确保可复现性 ) ``` 该方法通过设定比例随机划分数据集,需注意划分后的数据集不会重叠[^3]。 ### 二、数据预处理技巧 PyTorch通过`torchvision.transforms`模块实现数据预处理: ```python from torchvision import transforms # 训练集预处理(含数据增强) train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), # 随机裁剪缩放 transforms.RandomHorizontalFlip(), # 水平翻转增强 transforms.ToTensor(), # 转为张量 transforms.Normalize(mean=[0.485, 0.456, 0.406], # 归一化 std=[0.229, 0.224, 0.225]) ]) # 验证/测试集预处理(无增强) eval_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) ``` 数据增强可有效提升模型泛化能力,但验证/测试集应保持原始数据分布[^4]。 ### 三、数据加载实现 使用`DataLoader`进行批量加载: ```python from torch.utils.data import DataLoader train_loader = DataLoader( train_dataset, batch_size=64, shuffle=True, # 训练集需要打乱顺序 num_workers=4 # 并行加载进程数 ) val_loader = DataLoader( val_dataset, batch_size=32, shuffle=False ) ``` `DataLoader`支持多进程加载、自动批处理等功能,提升数据加载效率[^3]。 ### 四、自定义数据集示例 对于非标准数据格式,需继承`Dataset`类: ```python class CustomDataset(Dataset): def __init__(self, data_path, transform=None): self.data = load_custom_data(data_path) self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] if self.transform: sample = self.transform(sample) return sample ``` 通过实现`__getitem__``__len__`方法,可灵活处理任意格式数据[^2]。
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值