【pytorch】自定义读取数据集,使用txt文本

本文介绍了一种自定义加载图像数据集的方法,通过Python代码实现图像及其标签的读取,适用于有监督学习任务。该方法能有效减少内存消耗,适用于不同来源的图像数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用txt文本读入数据可以减少内存的需要,有时候自定义加载数据集是非常必要的,我下面的代码是针对图像的,并且带有label的有监督的图像。先看代码:

import numpy as np
import os
import torch.nn as nn
from PIL import Image


def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataSet(nn.Module):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataSet, self).__init__()
        fh = open(txt, 'r')
        imgs = []
        target = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0]))
            target.append(words[1])
        self.imgs = imgs
        self.target = target
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        fh.close()

    def __getitem__(self, index):
        fn= self.imgs[index]
        target=self.target[index]
        img = self.loader(fn)
        target=self.loader(target)
        img=np.asarray(img)
        target=np.asarray(target)
        # img = Image.fromarray(np.array(img), mode='L')
        # target = Image.fromarray(np.array(target), mode='L')

        if self.transform is not None:
            img = self.transform(img)
            target=self.transform(target)
        return img, target

    def __len__(self):
        return len(self.imgs)


def TxtData(canon_path, iphone_path, txt_path):
    txt_name = 'DPED.txt'
    txt_path = os.path.join(txt_path, txt_name)
    canon = os.listdir(canon_path)
    iphone = os.listdir(iphone_path)
    file = open(txt_path, 'w')

    if len(canon) != len(iphone):
        print('the number of traing data varies between the two files')

    for i in range(len(canon)):
        path = os.path.join(iphone_path, str(i) + '.jpg') + ' ' + os.path.join(canon_path, str(i) + '.jpg') + '\n'
        file.write(path)


if __name__ == '__main__':
    data_canon = './data/iphone/training_data/canon'
    data_iphone = './data/iphone/training_data/iphone'
    txt_path = './data'
    TxtData(data_canon, data_iphone, txt_path)

上面代码主要分两部分:

1:图像对分别在两个不同的文件夹,文件路径可自己更改,先运行代码会生成一个.txt文件,里面包括了所以图像的文件路径。txt文如下:

2:定义自己的加载数据的子类。每当使用数据时,也就是数据索引[ ]时会自己调用 def  __getitem__函数。这部分是在训练阶段 才会进行。

<think>首先,用户的问题是基于提供的引用内容:"请问如何在PyTorch中实现自定义数据集加载类"。引用内容讨论了在PyTorch中创建自定义数据集加载器,特别提到了继承`torch.utils.data.Dataset`类。 关键点从引用中: - 引用[1]:PyTorch 提供了`torch.utils.data.DataLoader`和`torch.utils.data.Dataset`来使用自定义数据。 - 引用[2]:自定义数据集通过继承`Dataset`类实现,需要定义三个魔法方法:`__init__`、`__len__`、`__getitem__`。 - 引用[3]:详细说明了如何实现子类,包括初始化、长度和获取项目。 - 引用[4]:提到了创建自定义数据集类并生成标签文件,提供了一个数据集下载链接。 系统级指令: - 所有行内数学表达式用`$...$`格式。 - 独立公式用`$$...$$`格式并单独成段。 - LaTeX语法正确。 - 使用中文回答。 - 生成相关问题。 - 回答中引用的段落末尾自然地添加引用标识,例如[^1]。 用户可见层: - 回答结构清晰,帮助用户逐步解决问题。 - 保证回答真实可靠。 用户的问题:如何在PyTorch中实现自定义数据集加载类?这应该包括创建自定义数据集类和使用数据加载器。 结构回答: 1. **介绍**:简要解释PyTorch自定义数据集的重要性。 2. **步骤**:逐步说明如何实现自定义数据集类。 - 继承`Dataset`类。 - 实现`__init__`方法:初始化数据路径、标签等。 - 实现`__len__`方法:返回数据集大小。 - 实现`__getitem__`方法:加载和预处理单个样本。 3. **使用数据加载器**:展示如何将自定义数据集传递给`DataLoader`。 4. **示例代码**:提供一个简单的代码示例。 5. **引用**:在相关部分末尾添加引用标识。 6. **生成相关问题**:在回答末尾添加相关问题。 确保使用中文回答。 生成相关问题: - 基于主题,生成3-4个相关问题,如优化、性能分析等。 示例代码:基于引用,提供一个自定义数据集类的示例。例如,使用一个图像数据集,如蜜蜂蚂蚁数据集(从引用[4])。 完整回答: - 开头:直接回答问题。 - 主体:逐步解释。 - 结尾:总结并添加相关问题。 引用标识:在回答中,当提到引用内容时,在段落末尾添加[^n]。例如,引用[2]讨论了三个魔法方法。 最终输出格式: - 回答部分。 - §§相关问题§§:列出相关问题。</think>### 在PyTorch中实现自定义数据集加载类 在PyTorch中,自定义数据集加载类允许用户处理非标准数据(如图像、文本或时间序列),通过继承`torch.utils.data.Dataset`类并实现关键方法来实现。这确保了数据能被高效加载和预处理,再通过`DataLoader`进行批处理和迭代。以下是基于PyTorch最佳实践的逐步实现指南,结合了引用内容的核心要点[^1][^2][^3]。 #### 步骤1: 理解自定义数据集类的结构 自定义数据集类必须继承`torch.utils.data.Dataset`,并覆盖三个核心方法: - `__init__(self, ...)`:初始化数据集参数(如数据路径、标签文件或预处理函数)。 - `__len__(self)`:返回数据集样本总数。 - `__getitem__(self, index)`:根据索引加载单个样本,并进行预处理(如归一化或数据增强)。 这些方法确保了数据集的兼容性,使其能被`DataLoader`无缝成[^2][^3]。 #### 步骤2: 实现自定义数据集类 以下是一个完整示例,使用图像数据集(如蜜蜂蚂蚁数据集[^4])作为场景。假设数据集包含图像文件和对应的标签CSV文件。 ```python import torch from torch.utils.data import Dataset from PIL import Image import pandas as pd import os class CustomImageDataset(Dataset): def __init__(self, img_dir, label_file, transform=None): """ 初始化数据集。 :param img_dir: 图像文件夹路径 :param label_file: 标签文件路径(CSV格式) :param transform: 数据预处理函数(如torchvision.transforms) """ super(CustomImageDataset, self).__init__() self.img_dir = img_dir self.labels = pd.read_csv(label_file) # 读取标签文件 self.transform = transform # 可选的数据增强 def __len__(self): """返回数据集大小。""" return len(self.labels) # 确保标签文件行数等于样本数[^3] def __getitem__(self, idx): """加载并预处理单个样本。""" img_path = os.path.join(self.img_dir, self.labels.iloc[idx, 0]) # 第一列为图像文件 image = Image.open(img_path).convert('RGB') # 加载图像 label = self.labels.iloc[idx, 1] # 第二列为标签值 if self.transform: image = self.transform(image) # 应用预处理(如Resize, ToTensor) return image, label # 返回样本和标签 ``` **关键点说明**: - **初始化 (`__init__`)**:设置数据路径和预处理函数。标签文件通常为CSV,格式如`image_name, label`[^4]。 - **长度 (`__len__`)**:返回样本数,确保索引范围有效。 - **获取项目 (`__getitem__`)**:加载图像并应用预处理(如使用`torchvision.transforms`),输出张量格式的样本和标签[^2][^3]。 - **预处理函数**:可使用PyTorch的`transforms`模块,例如: ```python from torchvision import transforms transform = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化 ]) ``` #### 步骤3: 使用DataLoader加载数据集 创建自定义数据集实例后,传递给`DataLoader`以实现批处理、洗牌和多线程加载。 ```python from torch.utils.data import DataLoader # 创建数据集实例 dataset = CustomImageDataset( img_dir='path/to/images', label_file='path/to/labels.csv', transform=transform ) # 创建DataLoader dataloader = DataLoader( dataset, batch_size=32, # 批大小 shuffle=True, # 训练时洗牌数据 num_workers=4 # 多线程加载,加速数据读取 ) # 示例:迭代数据集 for images, labels in dataloader: # 在此处进行模型训练 pass ``` **优点**: - **高效加载**:`DataLoader`支持并行数据加载,减少I/O瓶颈。 - **灵活性**:可处理各种数据类型,包括非图像数据(如修改`__getitem__`加载文本或音频)。 - **兼容性**:与PyTorch模型无缝成[^1][^3]。 #### 常见问题与优化 - **标签文件生成**:如果无现成标签,可编写脚本生成CSV文件(如遍历文件夹并分配标签)[^4]。 - **性能优化**:使用`num_workers`加速加载,或使用内存映射技术处理大文件。 - **错误处理**:在`__getitem__`中添加异常捕获(如文件缺失时跳过样本)。 通过以上步骤,您可以在PyTorch中高效实现自定义数据集加载类,适用于各种机器学习任务。自定义数据集的核心是确保数据格式一致性和预处理标准化,这能提升模型训练效果[^1][^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值