import os
import numpy as np
from PIL import Image
import pickle
def _to_text(value):
    """Return *value* as ``str``, decoding it first when it is ``bytes``."""
    return value.decode() if isinstance(value, bytes) else value


def convert_cifar_to_images(data_path, output_dir, is_cifar100=False):
    """Export one CIFAR batch file as image files grouped by class name.

    Each row of the batch's ``'data'`` array is reshaped from a flat vector
    to a 32x32x3 image and written to ``output_dir/<class name>/<filename>``.

    Args:
        data_path: Path of the pickled batch file. The label metadata file
            (``meta`` for CIFAR-100, ``batches.meta`` for CIFAR-10) is read
            from the same directory.
        output_dir: Root directory for the exported images. It and one
            sub-directory per class name are created when missing.
        is_cifar100: When true, class names come from ``'fine_label_names'``
            in ``meta``; otherwise from ``'label_names'`` in ``batches.meta``.

    Raises:
        KeyError: If the batch has no ``'data'`` entry or the metadata lacks
            the expected class-name key.
        ValueError: If the batch has neither ``'labels'`` nor
            ``'fine_labels'``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(security): unpickling can execute arbitrary code; only point this
    # at dataset files obtained from a trusted source.
    data = unpickle(data_path)

    # Load the class names from the metadata file next to the batch file.
    dataset_dir = os.path.dirname(data_path)
    if is_cifar100:
        meta = unpickle(os.path.join(dataset_dir, 'meta'))
        raw_names = meta['fine_label_names']
    else:
        meta = unpickle(os.path.join(dataset_dir, 'batches.meta'))
        raw_names = meta['label_names']
    # Decode once here instead of once per image inside the loop.
    class_names = [_to_text(name) for name in raw_names]

    # Create one folder per class.
    for name in class_names:
        os.makedirs(os.path.join(output_dir, name), exist_ok=True)

    images = data['data']

    # 'labels' and 'fine_labels' are the two label keys this function accepts.
    label_list = data.get('labels')
    if not label_list:
        label_list = data.get('fine_labels', label_list)
    if label_list is None:
        raise ValueError(
            f"{data_path!r} contains neither 'labels' nor 'fine_labels'"
        )

    # File names are optional: they must not drive the iteration, otherwise a
    # batch without them would silently export nothing.
    filenames = data.get('filenames') or []

    # Used to keep generated names unique when several batches share one
    # output directory.
    batch_name = os.path.basename(data_path)

    for idx, (pixels, label_idx) in enumerate(zip(images, label_list)):
        # Flat vector -> 3x32x32 (channel first) -> 32x32x3 (channel last).
        img = pixels.reshape(3, 32, 32).transpose(1, 2, 0)

        filename = _to_text(filenames[idx]) if idx < len(filenames) else ''
        if not filename:
            filename = f"{batch_name}_img_{idx}.png"

        output_path = os.path.join(output_dir, class_names[label_idx], filename)
        Image.fromarray(img).save(output_path)
def unpickle(file, encoding='latin1'):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file to read.
        encoding: Forwarded to ``pickle.load`` as its ``encoding`` argument.
            Defaults to ``'latin1'``, the value this function always used.

    Returns:
        The unpickled object.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # obtained from a trusted source.
    with open(file, 'rb') as f:
        return pickle.load(f, encoding=encoding)
# ===== CIFAR-10 conversion =====
cifar10_dir = r'C:\oodd_dataset\cifar-10-batches-py'
# Five training batches first, then the single test batch.
cifar10_jobs = [(f'data_batch_{i}', 'cifar10_images/train') for i in range(1, 6)]
cifar10_jobs.append(('test_batch', 'cifar10_images/test'))
for batch_file, target_dir in cifar10_jobs:
    convert_cifar_to_images(os.path.join(cifar10_dir, batch_file), target_dir)

# ===== CIFAR-100 conversion =====
cifar100_dir = r'C:\oodd_dataset\cifar-100-python'
cifar100_jobs = (
    ('train', 'cifar100_images/train'),
    ('test', 'cifar100_images/test'),
)
for split_file, target_dir in cifar100_jobs:
    convert_cifar_to_images(
        os.path.join(cifar100_dir, split_file),
        target_dir,
        is_cifar100=True,
    )

# Report where the exported images were written.
for message in (
    "转换完成!图像保存在以下文件夹中:",
    "CIFAR-10: cifar10_images/",
    "CIFAR-100: cifar100_images/",
):
    print(message)