paddle 自定义数据集和预处理

本文介绍了如何在PaddlePaddle库中创建自定义数据集,包括继承Dataset类、实现__init__和__getitem__/__len__方法,以及使用预处理功能如ToTensor、数据增强等。作者还展示了如何使用Compose对图像进行预处理,如中心裁剪、随机水平翻转和颜色调整。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

自定义数据集 

import paddle
from matplotlib import pyplot as plt
import os
import cv2
import numpy as np
from paddle.io import Dataset
from paddle.vision.transforms import Normalize
print('计算机视觉(CV)相关数据集:', paddle.vision.datasets.__all__)
print('自然语言处理(NLP)相关数据集:', paddle.text.__all__)

#图像数据集
# vis_dataset=paddle.vision.datasets.MNIST(mode='train',transform=paddle.vision.transforms.ToTensor())
# print(len(vis_dataset))
# image,label=vis_dataset[0]
# print(type(image))
# print(image.shape)
# print(label)

#文字数据集
# text_dataset=paddle.text.Imdb()
# text,label=text_dataset[1]
# print(type(text))
#
# print(label)

#图像显示
# for data in vis_dataset:
#     image,label=data
#     print('图片的shape',image.shape)
#     plt.title(str(label))
#     plt.imshow(image[0])
#     plt.show()

#自定义数据集
class MyDataset(Dataset):
    """
    步骤一:继承 paddle.io.Dataset 类
    """
    def __init__(self, data_dir, label_path, transform=None):
        """
        步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中
        """
        super().__init__()
        self.data_list = []
        with open(label_path,encoding='utf-8') as f:
            for line in f.readlines():
                image_path, label = line.strip().split('\t')
                image_path = os.path.join(data_dir, image_path)
                self.data_list.append([image_path, label])
        # 传入定义好的数据处理方法,作为自定义数据集类的一个属性
        self.transform = transform

    def __getitem__(self, index):
        """
        步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签)
        """
        # 根据索引,从列表中取出一个图像
        image_path, label = self.data_list[index]
        # 读取灰度图
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # 飞桨训练时内部数据格式默认为float32,将图像数据格式转换为 float32
        image = image.astype('float32')
        # 应用数据处理方法到图像上
        if self.transform is not None:
            image = self.transform(image)
        # CrossEntropyLoss要求label格式为int,将Label格式转换为 int
        label = int(label)
        # 返回图像和对应标签
        return image, label

    def __len__(self):
        """
        步骤四:实现 __len__ 函数,返回数据集的样本总数
        """
        return len(self.data_list)

# 定义图像归一化处理方法,这里的CHW指图像格式需为 [C通道数,H图像高度,W图像宽度]
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 打印数据集样本数
train_custom_dataset = MyDataset('mnist/train','mnist/train/label.txt', transform)
print('train_custom_dataset images: ',len(train_custom_dataset))

train_dataloader=paddle.io.DataLoader(dataset=train_custom_dataset,batch_size=64,shuffle=True,num_workers=0,drop_last=True)
#迭代dataloader并且显示图片
for batch_data in train_dataloader:
    image, label = batch_data
    col=10
    row=7
    for i in range(len(image)):

        plt.subplot(row,col ,i+1)
        plt.title(str(label[i].numpy()))
        plt.xticks([])
        plt.yticks([])
        plt.imshow(image[i][0])
    plt.show()
    break

#batchsamper
# from paddle.io import BatchSampler
#
# bs=BatchSampler(train_custom_dataset,batch_size=8,shuffle=True,drop_last=True)
# print('batchsamer每轮返回一个索引列表')
# for batch_indices in bs:
#     print(batch_indices)
#     break

数据预处理

import cv2
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from paddle.vision.transforms import CenterCrop,RandomHorizontalFlip,Compose,ColorJitter


transform =Compose([CenterCrop(20),
                    RandomHorizontalFlip(0.5),#基于概率来执行图片的水平翻转
                    ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)#随机调整图像的亮度、对比度、饱和度和色调
                    ])

#1、opencv读取图片
# image = cv2.imread('0.jpg')
# print(image.shape)
# image_after_transform = transform(image)
# print(image_after_transform.shape)
# plt.subplot(1,2,1)
# plt.title('origin image')
# plt.imshow(image[:,:,::-1])
# plt.subplot(1,2,2)
# plt.title('transform image')
# plt.imshow(image_after_transform[:,:,::-1])
# plt.show()

# 2、PIL读取图片
image=Image.open('0.jpg')
image_after_transform=transform(image)
plt.subplot(1,2,1)
plt.title('origin image')
plt.imshow(image)
plt.subplot(1,2,2)
plt.title('transform image')
plt.imshow(image_after_transform)
plt.show()

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一壶浊酒..

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值