小土堆pytorch学习笔记 之数据加载

部署运行你感兴趣的模型镜像

Pytorch 之数据加载

Dataset 是一个抽象类,用于定义数据集的加载和预处理方式。当你需要自定义数据集时,可以通过继承 Dataset 类并实现特定的方法来创建自己的数据集。

  1. Dataset 类位于 torch.utils.data 模块中。

  2. 要使用 Dataset 类,你需要继承它并实现以下两个方法:

    1. len(self):返回数据集中样本的数量。
    2. getitem(self, idx):根据给定的索引 idx 返回一个样本。
  3. 为了方便上面两个函数的书写,还需要再加一个初始化函数来存文件名相关的参量

[!TIP]
下面先来学习几个函数!(源于 kimi)

1 文件操作相关函数

os.pathos 模块中的一个子模块,专门用于处理文件路径。它提供了一组用于路径操作的函数,这些函数可以跨平台工作,即在 Windows、Linux 和 macOS 等操作系统上都能正常使用。

os.path.join 会根据操作系统使用正确的斜杠。

  • os.path.join(path, *paths):将多个路径组件合并成一个路径。
  • os.path.split(path):将路径分解为头和尾两部分。
  • os.path.basename(path):返回路径的基础文件名。
  • os.path.dirname(path):返回路径的目录名。
  • os.path.abspath(path):返回路径的绝对路径。
  • os.path.normpath(path):规范化路径。
  • os.path.exists(path):检查路径是否存在。
  • os.path.isfile(path):检查路径是否为文件。
  • os.path.isdir(path):检查路径是否为目录。
  • os.path.splitext(path):将路径分解为文件名和扩展名。

2 图片导入相关函数

2.1 通过 pillow 库

  1. 安装 Pillow 库

首先,确保你已经安装了 Pillow 库。如果没有安装,可以通过 pip 安装:

pip install Pillow
  1. 导入图片

使用 Pillow 的 Image 模块,你可以轻松地打开和处理图片。

from PIL import Image
_打开图片文件_
image = Image.open('path/to/your/image.jpg')
_显示图片(需要安装Pillow的额外依赖)_
image.show()
  1. 处理图片(可选)

在存储之前,你可以对图片进行各种处理,比如裁剪、旋转、调整大小等。

_旋转图片_
rotated_image = image.rotate(90)
_调整图片大小_
resized_image = image.resize((300, 300))
_裁剪图片_
cropped_image = image.crop((50, 50, 200, 200))
  1. 存储图片

处理完图片后,你可以将其保存到文件系统中。

_保存图片_
image.save('path/to/save/image.jpg')
_保存为不同的格式_
image.save('path/to/save/image.png', 'PNG')

2.2 使用 OpenCV

如果你需要更复杂的图像处理功能,可以使用 OpenCV 库。首先安装 OpenCV:

pip install opencv-python

然后,你可以使用 OpenCV 来读取和保存图片:

import cv2
_读取图片_
image = cv2.imread('path/to/your/image.jpg')
_显示图片(使用OpenCV的窗口)_
cv2.imshow('Image', image)
cv2.waitKey(0)
cv2.destroyAllWindows()
_保存图片_
cv2.imwrite('path/to/save/image.jpg', image)

2.3 注意事项

  • 当处理图像文件时,确保你有权限读取和写入指定的文件路径。
  • 保存图像时,可以通过指定不同的格式来改变图像的保存格式,例如 JPEG、PNG 等。
  • 在使用图像处理库时,了解库的文档是非常重要的,因为不同的库可能提供不同的功能和方法。

[!TIP]
正式开始编代码!
回忆一些最最基础的 python 知识,函数/类的定义后要加:,缩进表示范围

3 如何使用 Dataset

导入相关模块,定义 MyData 类,并重写 len(self)和getitem(self, idx)

  1. init 函数初始化一些文件名的操作,方便从电脑文件中导入
def init(self,root_dir,image_dir,label_dir):
    _#文件路径相关_
_    _self.root_dir = root_dir
    self.image_dir = image_dir
    self.label_dir = label_dir
    self.image_path = os.path.join(root_dir,image_dir)
    self.label_dir = os.path.join(root_dir,label_dir)
  1. 通过编号 idx 来取出某张图片“类名[idx]”
def getitem(self,idx):
    _#os.listdir(path)返回一个列表,其中包含了目录中的所有条目,包括文件和子目录_
_    _self.imgs_path = os.listdir(self.image_path)
    img_item_path = os.path.join(self.image_path,self.imgs_path[idx])
    img = Image.open(img_item_path)
    label = self.label_dir
    return img,label
  1. 重写 len 函数
def len(self):
    return len(self.imgs_path)
    #img文件夹中图片的数量
  1. 如何调用/调试?

通过 pycharm 下面的 consle 来调试,可以非常方便的展示所有变量

#通过目录等信息创建实例
root_dir = "dataset\train"
label_dir = "ants_label"
image_dir = "ants_image"
ants_label_dir = "ants"
ants_dataset = MyData(root_dir,image_dir,label_dir) 

#取出第一张图片并显示
img,label = ants_dataset[0]    
img.show()

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

PyTorch 是一个广泛使用的深度学习框架,小土堆PyTorch 学习笔记涵盖了从基础到进阶的多个知识点,适合初学者系统性地学习。以下是一些关键内容的总结与扩展,结合了实际代码示例和常见应用场景。 ### 数据集的构建 在 PyTorch 中,自定义数据集通常需要继承 `torch.utils.data.Dataset` 类,并实现 `__init__`、`__getitem__` 和 `__len__` 方法。以下是一个简单的自定义数据集类 `MyData` 的实现,用于加载图像数据: ```python from torch.utils.data import Dataset from PIL import Image import os class MyData(Dataset): def __init__(self, root_dir, label_dir): self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir, self.label_dir) self.img_path = os.listdir(self.path) def __getitem__(self, idx): img_name = self.img_path[idx] img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) img = Image.open(img_item_path) label = self.label_dir return img, label def __len__(self): return len(self.img_path) ``` 通过该类,可以轻松地加载图像数据集,并将其组合成训练集或验证集。例如,加载蚂蚁和蜜蜂两类图像数据并组合成一个完整的训练集: ```python root_dir_out = 'pytorch_xiaotudui/bee_ant/dataset' ants_label_dir = 'ants' ants_dataset = MyData(root_dir_out, ants_label_dir) bees_label_dir = 'bees' bees_dataset = MyData(root_dir_out, bees_label_dir) train_dataset = ants_dataset + bees_dataset # 合并两个数据集 ``` ### 模型预测与准确率计算 在模型训练和评估过程中,计算准确率是一个常见任务。PyTorch 提供了便捷的张量操作来实现这一目标。例如,假设模型的输出是一个二维张量,表示每个样本属于两个类别的概率,可以通过 `argmax` 方法获取预测结果: ```python import torch output = torch.tensor([[0.1, 0.2], [0.3, 0.4]]) preds = output.argmax(1) # 输出 [1, 1] targets = torch.tensor([0, 1]) # 计算预测是否正确 correct = (preds == targets) # 输出 [False, True] accuracy = correct.sum().item() / len(targets) # 计算准确率 print(f"Accuracy: {accuracy:.2f}") # 输出 Accuracy: 0.50 ``` ### 数据加载与批处理 为了提高训练效率,通常会使用 `DataLoader` 来批量加载数据。`DataLoader` 支持多线程加速和数据打乱等功能,适用于大规模数据集的训练。 ```python from torch.utils.data import DataLoader # 创建 DataLoader train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True) # 遍历数据 for images, labels in train_loader: # 进行模型训练或其他操作 print(images, labels) ``` ### 模型训练流程 PyTorch 的训练流程通常包括以下几个步骤: 1. **定义模型**:选择或自定义神经网络结构。 2. **定义损失函数**:如交叉熵损失 `CrossEntropyLoss`。 3. **定义优化器**:如随机梯度下降 `SGD` 或 Adam 优化器。 4. **前向传播**:输入数据经过模型得到输出。 5. **计算损失**:根据输出和真实标签计算损失。 6. **反向传播**:通过 `loss.backward()` 计算梯度。 7. **更新参数**:调用优化器的 `step()` 方法更新权重。 8. **清空梯度**:调用 `optimizer.zero_grad()` 避免梯度累积。 以下是一个简单的训练循环示例: ```python import torch.nn as nn import torch.optim as optim # 假设定义一个简单的模型 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.fc = nn.Linear(10, 2) def forward(self, x): return self.fc(x) model = SimpleModel() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 假设输入数据 inputs = torch.randn(2, 10) targets = torch.tensor([0, 1]) # 前向传播 outputs = model(inputs) loss = criterion(outputs, targets) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() print(f"Loss: {loss.item():.4f}") ``` ###
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值