pytorch和numpy保存和加载数据集的三种方法

已CIFAR10数据集为例子,下面重点介绍下三种不同的保存和加载CIFAR10数据集的方法,所有的下载CIFAR10代码如下。

  • 用torch.save保存和torch.load加载
  • 用numpy的savez保存和numpy的load数据集,已经torch.tensor和DataLoader将其转成pytorch的目标数据集。这个方法有两种策略,一个是保存原始图片的data,一个是保存归一化后的data。具体方法的完整代码如下。

1. torch.save/torch.load保存完整CIFAR10对象

import os
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

data_dir = './data'  #定义数据存储的目录
train_path = os.path.join(data_dir, 'trainset.pt')  #训练集保存路径
test_path = os.path.join(data_dir, 'testset.pt')  #测试集保存路径

transform = transforms.ToTensor()  #定义图像预处理操作:将图像转换为张量

#如果训练集和测试集的文件已经存在
if os.path.exists(train_path) and os.path.exists(test_path):
    print("Loading dataset from torch saved files...")  #打印提示信息
    trainset = torch.load(train_path)  #从文件中加载训练集
    testset = torch.load(test_path)  #从文件中加载测试集
else:
    print("Downloading CIFAR10 dataset and saving with torch.save...")  #打印提示信息
    #下载 CIFAR10 训练集并应用预处理
    trainset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
    #下载 CIFAR10 测试集并应用预处理
    testset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

    #将训练集保存为 .pt 文件
    torch.save(trainset, train_path)
    #将测试集保存为 .pt 文件
    torch.save(testset, test_path)
#创建训练集的数据加载器,批大小为64,打乱数据顺序
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
#创建测试集的数据加载器,批大小为64,不打乱数据顺序
test_loader = DataLoader(testset, batch_size=64, shuffle=False)
#打印训练集和测试集的样本数量
print(f"Train samples: {len(trainset)}, Test samples: {len(testset)}")

2.np.savez保存经过transform的tensor

import os
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

data_dir = './data'  #定义数据存储的目录
train_npz_path = os.path.join(data_dir, 'trainset_processed.npz')  #训练集保存路径(.npz 格式)
test_npz_path = os.path.join(data_dir, 'testset_processed.npz')  #测试集保存路径(.npz 格式)
transform = transforms.ToTensor()  #定义图像预处理操作:将图像转换为张量

#定义函数:将数据集保存为 .npz 格式
def save_processed_dataset(dataset, path):
    data_list = []  #用于存储图像数据
    target_list = []  #用于存储标签
    for i in range(len(dataset)):  #遍历数据集中的每个样本
        img, label = dataset[i]  #获取图像和标签
        data_list.append(img.numpy())  #将图像张量转换为 NumPy 数组并添加到列表中
        target_list.append(label)  #添加标签
    np.savez(path, data=np.stack(data_list), targets=np.array(target_list))  #保存为 .npz 文件
    print(f"Saved processed dataset to {path}")  #打印保存路径

#定义函数:从 .npz 文件中加载数据集
def load_processed_dataset(path):
    loaded = np.load(path)  #加载 .npz 文件
    data = torch.tensor(loaded['data'], dtype=torch.float32)  #将图像数据转换为 float32 类型的张量
    targets = torch.tensor(loaded['targets'], dtype=torch.long)  #将标签转换为 long 类型的张量
    return TensorDataset(data, targets)  #返回一个张量数据集对象

#如果处理后的数据集文件存在,则直接加载
if os.path.exists(train_npz_path) and os.path.exists(test_npz_path):
    print("Loading processed dataset from .npz files...")  #打印提示信息
    trainset = load_processed_dataset(train_npz_path)  #加载训练集
    testset = load_processed_dataset(test_npz_path)  #加载测试集
else:
    print("Downloading CIFAR10 dataset and saving processed tensors to .npz...")  #打印提示信息
    #下载原始 CIFAR10 训练集并应用预处理
    train_original = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
    #下载原始 CIFAR10 测试集并应用预处理
    test_original = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

    #保存处理后的训练集和测试集为 .npz 文件
    save_processed_dataset(train_original, train_npz_path)
    save_processed_dataset(test_original, test_npz_path)

    #加载保存后的数据集
    trainset = load_processed_dataset(train_npz_path)
    testset = load_processed_dataset(test_npz_path)

#创建训练集的数据加载器,批大小为64,打乱数据顺序
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
#创建测试集的数据加载器,批大小为64,不打乱数据顺序
test_loader = DataLoader(testset, batch_size=64, shuffle=False)
#打印训练集和测试集的样本数量
print(f"Train samples: {len(trainset)}, Test samples: {len(testset)}")

3. np.savez保存原始.data和.targets

import os
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

data_dir = './data'  #定义数据存储目录
train_npz_path = os.path.join(data_dir, 'trainset_raw.npz')  #训练集保存路径(原始格式)
test_npz_path = os.path.join(data_dir, 'testset_raw.npz')  #测试集保存路径(原始格式)

#定义函数:保存原始数据集为.npz文件
#注意:不使用 transform,保留原始uint8图像数据
def save_raw_dataset(dataset, path):
    np.savez(path, data=dataset.data, targets=np.array(dataset.targets))  #保存图像数据和标签
    print(f"Saved raw dataset to {path}")  #打印保存路径

#定义函数:从.npz文件中加载原始数据集
def load_raw_dataset(path):
    loaded = np.load(path)  #加载.npz文件
    data = torch.tensor(loaded['data'], dtype=torch.float32)  #将uint8图像数据转换为float32类型张量
    data = data.permute(0, 3, 1, 2)  #通道维度转换:从 [N, H, W, C] 转为 [N, C, H, W]
    data /= 255.0  #将像素值归一化到 [0, 1]
    targets = torch.tensor(loaded['targets'], dtype=torch.long)  #将标签转换为 long 类型张量
    return TensorDataset(data, targets)  #返回张量数据集对象

#如果原始数据集文件已存在,则直接加载
if os.path.exists(train_npz_path) and os.path.exists(test_npz_path):
    print("Loading raw dataset from.npzfiles...")  #打印提示信息
    trainset = load_raw_dataset(train_npz_path)  #加载训练集
    testset = load_raw_dataset(test_npz_path)  #加载测试集
else:
    print("Downloading CIFAR10 dataset and saving raw data to .npz...")  #打印提示信息
    #下载原始 CIFAR10 数据集(不使用 transform,保留原始格式)
    train_original = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=None)
    test_original = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=None)

    #保存原始训练集和测试集为.npz文件
    save_raw_dataset(train_original, train_npz_path)
    save_raw_dataset(test_original, test_npz_path)

    #加载保存后的数据集
    trainset = load_raw_dataset(train_npz_path)
    testset = load_raw_dataset(test_npz_path)

#创建训练集的数据加载器,批大小为64,打乱数据顺序
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
#创建测试集的数据加载器,批大小为64,不打乱数据顺序
test_loader = DataLoader(testset, batch_size=64, shuffle=False)
#打印训练集和测试集的样本数量
print(f"Train samples: {len(trainset)}, Test samples: {len(testset)}")

方法一:使用torch.save/torch.load
这是最直接、最符合PyTorch用户习惯的做法。我们将经过transform处理后的CIFAR10数据集对象直接使用torch.save()序列化为.pt文件,之后使用torch.load()还原为完整对象。
这种方式的优点是简单快捷,代码量少,保存和加载速度都非常快;同时能保留PyTorch对象原始结构,直接配合DataLoader使用即可开始训练,无需任何额外处理。但它的缺点也明显,由于使用Python的pickle序列化,跨平台兼容性较差;如果数据集对象包含了复杂的transform或定制类,可能在不同版本PyTorch或其他语言环境中无法加载。
因此,这种方法适合自己训练模型、本地复现实验或快速缓存数据集,但不适合共享或长期存档。

方法二:使用numpy.savez保存transform后的张量(trainset[i][0])
该方法遍历整个数据集,依次取出transform后的图像张量(通常为[C,H,W]且已经归一化到[0,1]),和其对应标签,统一保存到.npz文件中。这种方式保存的数据格式已经符合训练要求,加载后可直接转换为TensorDataset使用,无需做通道转置、归一化等处理。
它的优点在于兼容性强(保存为标准NumPy格式),加载后无须复杂预处理即可用于训练。相比方法一,它更适合于数据共享、跨平台模型部署等使用场景。但因为要遍历整个数据集,转换为NumPy,再存盘,所以保存过程比方法一稍慢;同时由于是保存归一化后的float32张量,文件体积也比原始数据略大。
整体而言,方法二是一个兼顾实用性和通用性的方法,推荐在需要长期保存数据或和他人共享实验数据时使用。

方法三:使用numpy.savez保存原始.data和.targets
这种方式不依赖transform,而是直接读取CIFAR10数据集对象的.data属性(即原始的RGB图像,[H,W,C],uint8类型)和.targets(整数列表),保存为.npz文件。相比方法二,它速度更快,占用空间更小(因为保存的是uint8图像而非float32),但缺点是加载后不能直接用于训练。
在加载过程中,需要手动对数据进行以下处理:首先将图像张量从[N,H,W,C]转换为[N,C,H,W];然后将像素值从整数类型归一化为[0,1]的浮点数。这使得代码相对复杂一些,但也给予了更大的灵活性。
这种方法适合对数据预处理流程有特殊要求的用户,或希望将数据提供给非PyTorch平台(如TensorFlow、Keras、ONNX等)使用,或出于存储/发布原因希望保留原始图像的原貌。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值