pytorch 读取数据

本文详细介绍如何使用PyTorch的ImageFolder加载图像数据集,包括数据预处理、批处理和加载过程。通过实例演示了如何定义数据转换,如灰度转换、调整大小、张量化及归一化,并展示如何创建DataLoader进行数据批量加载。
部署运行你感兴趣的模型镜像

pytorch Dataset 的ImageFolder
ImageFolder例子

def load_data(root_dir,domain,batch_size):
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0,0,0),std=(1,1,1)),
    ]
    )
    image_folder = datasets.ImageFolder(
            root=root_dir + domain,
            transform=transform
        )
    data_loader = torch.utils.data.DataLoader(dataset=image_folder,batch_size=batch_size,shuffle=True,num_workers=2,drop_last=True
    )
    return data_loader
    
data_src = data_loader.load_data(
        root_dir=rootdir, domain='amazon', batch_size=BATCH_SIZE[0])
        
for e in tqdm(range(1, N_EPOCH + 1)):
        model = train(model=model, optimizer=optimizer,
                      epoch=e, data_src=data_src, data_tar=data_tar)

【PyTorch学习笔记】14:划分训练-验证-测试集,使用正则化项
torch.utils.data.random_split源码

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

PyTorch读取数据集的方式一般有三种情况,具体如下: 1. 读取官方给的数据集,例如Imagenet,CIFAR10,MNIST等。这些库调用`torchvision.datasets.XXXX()`即可,例如想要读取MNIST数据集: ```python import torch import torch.nn as nn import torch.utils.data as Data import torchvision train_data = torchvision.datasets.MNIST( root='./mnist/', train=True, # this is training data transform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0] download=True, ) ``` 2. 使用`torch.utils.data.Dataset`和`torch.utils.data.DataLoader`自定义数据集读取方式。这种方式需要自己定义数据集的读取方式,可以适用于各种数据集,例如图像、文本、音频等。具体实现可以参考以下代码: ```python class MyDataset(Data.Dataset): def __init__(self, data, label): self.data = data self.label = label def __getitem__(self, index): x = self.data[index] y = self.label[index] return x, y def __len__(self): return len(self.data) train_data = MyDataset(data, label) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) ``` 3. 直接读取数据文件,例如txt、csv等。这种方式需要自己编写读取文件的代码,可以适用于各种格式的数据文件。具体实现可以参考以下代码: ```python class MyDataset(Data.Dataset): def __init__(self, file_path): self.data = [] self.label = [] with open(file_path, 'r') as f: for line in f: line = line.strip().split(',') self.data.append(line[:-1]) self.label.append(line[-1]) def __getitem__(self, index): x = self.data[index] y = self.label[index] return x, y def __len__(self): return len(self.data) train_data = MyDataset(file_path) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值