读取图片数据集

文章介绍了如何利用torchvision.datasets.ImageFolder读取和处理图片数据集,包括对图片进行随机裁剪、划分训练集和测试集,以及将图片转化为张量数据的过程。通过使用transforms.Compose将图片转换为模型训练所需的格式。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

读取文件夹的图片数据

在处理图片数据集的时候,会出现怎么将图片数据集转化为可以用于训练和测试的张量数据集的问题。

进行分类的情况下:一般图片数据集的存放方法如下:

目录结构

需要用到的函数torchvision.datasets.ImageFolder()

ImageFolder(root,transform=None,target_transform=None,loader=default_loader)

root: 图片总目录,子层级为各类型对应的文件目录。
transform: 对PIL image进行转换操作,transform输入的loader读取图片返回的对象
target_transform: 对label进行变换
loader: 指定加载图片的函数,默认操作是读取为RGB格式的PIL image对象 

实际使用实践:

from torchvision import transforms,utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt

#对图片进行随机裁剪256的大小
trans = transforms.Compose([
    transforms.RandomResizedCrop(256)
])

#使用ImageFolder读取图片
train_data = datasets.ImageFolder('data/flower_photos/train',transform=trans)

#读取了所有类别下的所有图片
print(len(train_data))
3669
 
#获取类别
print(train_data.classes)
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

#类别与标签索引
print(train_data.class_to_idx)
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

#图片对应类别标签
print(train_data.targets)
[0,1,2,3,4]


#每个元素是元组
data[0]
(<PIL.Image.Image image mode=RGB size=256x256 at 0x2120C9A07F0>, 0)


#查看所有图片路径与类别标签
data.imgs
('data/flower_photos/train\\daisy\\2611119198_9d46b94392.jpg', 0)

在训练时,一般会先按照一定的比例划分训练集和测试集,下面是文件操作,将文件夹下的文件按照3:1分别分为train和test两个文件夹。


# 将花图片以3:1的方法分配到train和test文件下
data_dir='data/flower_photos'
classes=['daisy','dandelion','roses','sunflowers','tulips']
train_dir=os.path.join(data_dir,'train')
test_dir=os.path.join(data_dir,'test')

for cls in classes:
    os.makedirs(os.path.join(train_dir,cls),exist_ok=True)
    os.makedirs(os.path.join(test_dir,cls),exist_ok=True)

for cls in classes:
    src_dir=os.path.join(data_dir,cls)
    files=os.listdir(src_dir)
    np.random.shuffle(files)
    train_files=files[:int(len(files)*0.75)]
    test_files=files[int(len(files)*0.75):]

    for file in train_files:
        src_path=os.path.join(src_dir,file)
        dst_path=os.path.join(train_dir,cls,file)
        print(src_path)
        img=Image.open(src_path)
        img.save(dst_path)
    for file in test_files:
        src_path=os.path.join(src_dir,file)
        dst_path=os.path.join(test_dir,cls,file)
        print(src_path)
        img=Image.open(src_path)
        img.save(dst_path)

在使用数据时候,还需要使用transforms.Compose()将图片转化为张量数据:

#使用transform将图片向量转化为张量
transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5),(0.5))
])
# 使用torchvision.datasets.ImageFolder将图片数据集读出来
train_data=torchvision.datasets.ImageFolder('data/flower_photos/train',transform=transform)
test_data=torchvision.datasets.ImageFolder('data/flower_photos/train',transform=transform)
# train_data.classes('daisy','dandelion','roses','sunflowers','tulips')
# print(train_data.classes)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值