pytorch学习(三):Dataset和Dataloader(加载数据)

自定义DataSet

数字

from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import torch

# 自构建数据集
dataset = TensorDataset(torch.arange(1, 40))
dl = DataLoader(dataset, batch_size=10)
# 数据输出
for batch in dl:
    print(batch)

图片

from torch.utils.data import Dataset
from PIL import Image   #这里用来读取图片数据
import os

class MyData(Dataset):  # 这里定义了一个MyData类继承Dataset读取数据

    def __init__(self, root_dir, label_dir):        #进行初始化数据
        self.root_dir = root_dir        #比如 "data/train"  相对路径
        self.label_dir = label_dir      #这里可以是 "ant"
        self.path = os.path.join(self.root_dir,self.label_dir)      #利用一个函数将其合并为ant标签对应的路径
        # os.listdir()会将path路径下对应的所有文件的“文件名”转化成列表中的一个个元素
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):  #实现这个方法是为了获取第idx个文件对应的数据和标签的
        img_name = self.img_path[idx]   #这个列表就可以返回名字呢
        img_item_path = os.path.join(self.path,img_name)    #将其之前的路径和文件名拼起来就是最终路径了
        img = Image.open(img_item_path)  #数据读取到img变量里了
        label = self.label_dir  #标签就是上层文件夹名
        return img,label       #返回img,label完成重写函数任务

    def __len__(self):
        return len(self.img_path)  #列表的长度就是数据的长度
root_dir = "../../../../../datas/hymenoptera_data/train"
ants_lable_dir = "ants"
bees_lable_dir = "bees"

imageList = MyData(root_dir, ants_lable_dir)
imge, lable = imageList[1]
imge.show()
print(lable)

显示带有地标的图片

from __future__ import print_function, division
import os
import torch
import pandas as pd              #用于更容易地进行csv解析
from skimage import io, transform    #用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# 忽略警告
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

# scikit-image:用于图像的IO和变换
# pandas:用于更容易地进行csv解析


# 读取数据
# 将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量。读取数据代码如下:
landmarks_frame = pd.read_csv('../../../../../datas/faces/face_landmarks.csv')

n = 67
# 获取第n行的图像名称(第0列)
img_name = landmarks_frame.iloc[n, 0]
print('Image name: {}'.format(img_name))
# 获取第n行的地标数据(第1列及之后的所有列),并转换为矩阵
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
print(landmarks)
# 将地标数据类型转换为浮点型,并重塑为二维数组,每个地标点包含两个值(x和y坐标)
print(landmarks.astype('float'))
landmarks = landmarks.astype('float').reshape(-1, 2)
print(landmarks)

print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))


def show_landmarks(image, landmarks):
    """显示带有地标的图片"""
    # 使用 matplotlib.pyplot 的 imshow 函数显示图像。
    plt.imshow(image)
    # 使用 scatter 函数在图像上绘制地标点。
    # landmarks[:, 0] 和 landmarks[:, 1] 分别表示地标点的 x 坐标和 y 坐标,
    # s=10 设置点的大小,marker='.' 设置点的形状为小圆点,c='r' 设置点的颜色为红色。
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    # 暂停0.001秒,以确保图像更新并显示出来。这在某些环境中是必要的,以确保绘图窗口能够及时刷新。
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('../../../../../datas/faces/', img_name)),landmarks)
plt.show()

dataloader() 方法

上面我们知道了dataset 这个数据集之后,为啥还需要适用datalaoder, 这个时候如果你是从java过来的, 应该知道classloader,需要对内容进行加载,如果不是从java过来也没关系。
dataset 我们可以比作时一副扑克牌,但是我们怎么抓到每个人手里是有dataloader说了算

简单的例子

from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import torch

# 自构建数据集
dataset = TensorDataset(torch.arange(1, 40))
r"""
batch_size 一批加载多少个数据
shuffle 随机的意思, 默认是false, 如果是True 每次加载的数据顺序就不一样
drop_last 最后不够batch_size 这么多时,数据是否丢弃
num_workers 表示几个子进程进行数据加载, 0 表示在主现场中加载
"""
dl = DataLoader(dataset, batch_size=9, shuffle = True, drop_last = True, num_workers = 0)
# 数据输出
for batch in dl:
    print(batch)

图片加载的例子

import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10("../../../../../datas/download",train=True,transform=data_transform,download=True)
test_data = torchvision.datasets.CIFAR10("../../../../../datas/download",train=False,transform=data_transform,download=True)

test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=0, drop_last= True)

image, target = test_data[0]
print(image.shape)
print(target)

writer = SummaryWriter("../../../../../datas/03_tensorboard")

step = 0
for data in test_dataloader:
    imgs, targets =data
    # print(imgs.shape)
    # print(targets)
    writer.add_images("test_data_loader", imgs, step)
    step = step + 1

for epoch in range(2):
    step = 0
    for data in test_dataloader:
        imgs, targets =data
        # print(imgs.shape)
        # print(targets)
        writer.add_images("test_data_loader_epoch_{}".format(epoch), imgs, step)
        step = step + 1
writer.close()
PyTorch中,数据读取是构建深度学习模型的重要一环。为了高效处理大规模数据集,PyTorch提供了个主要的工具:DatasetDataLoaderTensorDatasetDataset是一个抽象类,用于自定义数据集。我们可以继承Dataset类,并重写其中的__len____getitem__方法来实现自己的数据加载逻辑。__len__方法返回数据集的大小,而__getitem__方法根据给定的索引返回样本对应的标签。通过自定义Dataset类,我们可以灵活地处理各种类型的数据集。 DataLoader数据加载器,用于对数据集进行批量加载。它接收一个Dataset对象作为输入,并可以定义一些参数例如批量大小、是否乱序等。DataLoader能够自动将数据集划分为小批次,将数据转换为Tensor形式,然后通过迭代器的方式供模型训练使用。DataLoader数据准备模型训练的过程中起到了桥梁作用。 TensorDataset是一个继承自Dataset的类,在构造时将输入数据目标数据封装成Tensor。通过TensorDataset,我们可以方便地处理Tensor格式的数据集。TensorDataset可以将多个Tensor按行对齐,即将第i个样本从各个Tensor中取出,构成一个新的Tensor作为数据集的一部分。这对于处理多输入或者多标签的情况非常有用。 总结来说,Dataset提供了自定义数据集的接口,DataLoader提供了批量加载数据集的能力,而TensorDataset则使得我们可以方便地处理Tensor格式的数据集。这个工具的配合使用可以使得数据处理变得更加方便高效。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值