已CIFAR10数据集为例子,下面重点介绍下三种不同的保存和加载CIFAR10数据集的方法,所有的下载CIFAR10代码如下。
- 用torch.save保存和torch.load加载
- 用numpy的savez保存和numpy的load数据集,已经torch.tensor和DataLoader将其转成pytorch的目标数据集。这个方法有两种策略,一个是保存原始图片的data,一个是保存归一化后的data。具体方法的完整代码如下。
1. torch.save/torch.load保存完整CIFAR10对象
import os
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
data_dir = './data' #定义数据存储的目录
train_path = os.path.join(data_dir, 'trainset.pt') #训练集保存路径
test_path = os.path.join(data_dir, 'testset.pt') #测试集保存路径
transform = transforms.ToTensor() #定义图像预处理操作:将图像转换为张量
#如果训练集和测试集的文件已经存在
if os.path.exists(train_path) and os.path.exists(test_path):
print("Loading dataset from torch saved files...") #打印提示信息
trainset = torch.load(train_path) #从文件中加载训练集
testset = torch.load(test_path) #从文件中加载测试集
else:
print("Downloading CIFAR10 dataset and saving with torch.save...") #打印提示信息
#下载 CIFAR10 训练集并应用预处理
trainset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
#下载 CIFAR10 测试集并应用预处理
testset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
#将训练集保存为 .pt 文件
torch.save(trainset, train_path)
#将测试集保存为 .pt 文件
torch.save(testset, test_path)
#创建训练集的数据加载器,批大小为64,打乱数据顺序
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
#创建测试集的数据加载器,批大小为64,不打乱数据顺序
test_loader = DataLoader(testset, batch_size=64, shuffle=False)
#打印训练集和测试集的样本数量
print(f"Train samples: {len(trainset)}, Test samples: {len(testset)}")
2.np.savez保存经过transform的tensor
import os
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
data_dir = './data' #定义数据存储的目录
train_npz_path = os.path.join(data_dir, 'trainset_processed.npz') #训练集保存路径(.npz 格式)
test_npz_path = os.path.join(data_dir, 'testset_processed.npz') #测试集保存路径(.npz 格式)
transform = transforms.ToTensor() #定义图像预处理操作:将图像转换为张量
#定义函数:将数据集保存为 .npz 格式
def save_processed_dataset(dataset, path):
data_list = [] #用于存储图像数据
target_list = [] #用于存储标签
for i in range(len(dataset)): #遍历数据集中的每个样本
img, label = dataset[i] #获取图像和标签
data_list.append(img.numpy()) #将图像张量转换为 NumPy 数组并添加到列表中
target_list.append(label) #添加标签
np.savez(path, data=np.stack(data_list), targets=np.array(target_list)) #保存为 .npz 文件
print(f"Saved processed dataset to {path}") #打印保存路径
#定义函数:从 .npz 文件中加载数据集
def load_processed_dataset(path):
loaded = np.load(path) #加载 .npz 文件
data = torch.tensor(loaded['data'], dtype=torch.float32) #将图像数据转换为 float32 类型的张量
targets = torch.tensor(loaded['targets'], dtype=torch.long) #将标签转换为 long 类型的张量
return TensorDataset(data, targets) #返回一个张量数据集对象
#如果处理后的数据集文件存在,则直接加载
if os.path.exists(train_npz_path) and os.path.exists(test_npz_path):
print("Loading processed dataset from .npz files...") #打印提示信息
trainset = load_processed_dataset(train_npz_path) #加载训练集
testset = load_processed_dataset(test_npz_path) #加载测试集
else:
print("Downloading CIFAR10 dataset and saving processed tensors to .npz...") #打印提示信息
#下载原始 CIFAR10 训练集并应用预处理
train_original = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
#下载原始 CIFAR10 测试集并应用预处理
test_original = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
#保存处理后的训练集和测试集为 .npz 文件
save_processed_dataset(train_original, train_npz_path)
save_processed_dataset(test_original, test_npz_path)
#加载保存后的数据集
trainset = load_processed_dataset(train_npz_path)
testset = load_processed_dataset(test_npz_path)
#创建训练集的数据加载器,批大小为64,打乱数据顺序
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
#创建测试集的数据加载器,批大小为64,不打乱数据顺序
test_loader = DataLoader(testset, batch_size=64, shuffle=False)
#打印训练集和测试集的样本数量
print(f"Train samples: {len(trainset)}, Test samples: {len(testset)}")
3. np.savez保存原始.data和.targets
import os
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
data_dir = './data' #定义数据存储目录
train_npz_path = os.path.join(data_dir, 'trainset_raw.npz') #训练集保存路径(原始格式)
test_npz_path = os.path.join(data_dir, 'testset_raw.npz') #测试集保存路径(原始格式)
#定义函数:保存原始数据集为.npz文件
#注意:不使用 transform,保留原始uint8图像数据
def save_raw_dataset(dataset, path):
np.savez(path, data=dataset.data, targets=np.array(dataset.targets)) #保存图像数据和标签
print(f"Saved raw dataset to {path}") #打印保存路径
#定义函数:从.npz文件中加载原始数据集
def load_raw_dataset(path):
loaded = np.load(path) #加载.npz文件
data = torch.tensor(loaded['data'], dtype=torch.float32) #将uint8图像数据转换为float32类型张量
data = data.permute(0, 3, 1, 2) #通道维度转换:从 [N, H, W, C] 转为 [N, C, H, W]
data /= 255.0 #将像素值归一化到 [0, 1]
targets = torch.tensor(loaded['targets'], dtype=torch.long) #将标签转换为 long 类型张量
return TensorDataset(data, targets) #返回张量数据集对象
#如果原始数据集文件已存在,则直接加载
if os.path.exists(train_npz_path) and os.path.exists(test_npz_path):
print("Loading raw dataset from.npzfiles...") #打印提示信息
trainset = load_raw_dataset(train_npz_path) #加载训练集
testset = load_raw_dataset(test_npz_path) #加载测试集
else:
print("Downloading CIFAR10 dataset and saving raw data to .npz...") #打印提示信息
#下载原始 CIFAR10 数据集(不使用 transform,保留原始格式)
train_original = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=None)
test_original = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=None)
#保存原始训练集和测试集为.npz文件
save_raw_dataset(train_original, train_npz_path)
save_raw_dataset(test_original, test_npz_path)
#加载保存后的数据集
trainset = load_raw_dataset(train_npz_path)
testset = load_raw_dataset(test_npz_path)
#创建训练集的数据加载器,批大小为64,打乱数据顺序
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
#创建测试集的数据加载器,批大小为64,不打乱数据顺序
test_loader = DataLoader(testset, batch_size=64, shuffle=False)
#打印训练集和测试集的样本数量
print(f"Train samples: {len(trainset)}, Test samples: {len(testset)}")
方法一:使用torch.save/torch.load
这是最直接、最符合PyTorch用户习惯的做法。我们将经过transform处理后的CIFAR10数据集对象直接使用torch.save()序列化为.pt文件,之后使用torch.load()还原为完整对象。
这种方式的优点是简单快捷,代码量少,保存和加载速度都非常快;同时能保留PyTorch对象原始结构,直接配合DataLoader使用即可开始训练,无需任何额外处理。但它的缺点也明显,由于使用Python的pickle序列化,跨平台兼容性较差;如果数据集对象包含了复杂的transform或定制类,可能在不同版本PyTorch或其他语言环境中无法加载。
因此,这种方法适合自己训练模型、本地复现实验或快速缓存数据集,但不适合共享或长期存档。
方法二:使用numpy.savez保存transform后的张量(trainset[i][0])
该方法遍历整个数据集,依次取出transform后的图像张量(通常为[C,H,W]且已经归一化到[0,1]),和其对应标签,统一保存到.npz文件中。这种方式保存的数据格式已经符合训练要求,加载后可直接转换为TensorDataset使用,无需做通道转置、归一化等处理。
它的优点在于兼容性强(保存为标准NumPy格式),加载后无须复杂预处理即可用于训练。相比方法一,它更适合于数据共享、跨平台模型部署等使用场景。但因为要遍历整个数据集,转换为NumPy,再存盘,所以保存过程比方法一稍慢;同时由于是保存归一化后的float32张量,文件体积也比原始数据略大。
整体而言,方法二是一个兼顾实用性和通用性的方法,推荐在需要长期保存数据或和他人共享实验数据时使用。
方法三:使用numpy.savez保存原始.data和.targets
这种方式不依赖transform,而是直接读取CIFAR10数据集对象的.data属性(即原始的RGB图像,[H,W,C],uint8类型)和.targets(整数列表),保存为.npz文件。相比方法二,它速度更快,占用空间更小(因为保存的是uint8图像而非float32),但缺点是加载后不能直接用于训练。
在加载过程中,需要手动对数据进行以下处理:首先将图像张量从[N,H,W,C]转换为[N,C,H,W];然后将像素值从整数类型归一化为[0,1]的浮点数。这使得代码相对复杂一些,但也给予了更大的灵活性。
这种方法适合对数据预处理流程有特殊要求的用户,或希望将数据提供给非PyTorch平台(如TensorFlow、Keras、ONNX等)使用,或出于存储/发布原因希望保留原始图像的原貌。