PyTorch教程_数据的加载(小土堆)

安装环境部分:
1、Anaconda的安装(自行参考其他博客)
2、使用pip安装pytorch

  • 打开anaconda prompt创建一个pytorch的环境
  • conda create -n pytorch_gpu python=3.7 pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
  • 安装COUD11(可以查看自己显卡的种类去pytorch官网上指定下载)
  • 测试 import
  • torch.cuda.is_available()
  • end

代码部分:

import os

from PIL import Image
from torch.utils.data import Dataset


# dataset有两个作用:1、加载每一个数据,并获取其label;2、用len()查看数据集的长度
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # 初始化,为这个函数用来设置在类中的全局变量
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)  # 单纯的连接起来而已,背下来怎么用就好了,因为在win下和linux下的斜线方向不一样,所以用这个函数来连接路径
        self.img_path = os.listdir(self.path)  # img_path 的返回值,就已经是一个列表了

    def __getitem__(self, idx):  # 获取数据对应的 label
        img_name = self.img_path[idx]  # img_name 在上一个函数的最后,返回就是一个列表了
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 这行的返回,是一个图片的路径,加上图片的名称了,能够直接定位到某一张图片了
        img = Image.open(img_item_path)  # 这个步骤看来是不可缺少的,要想 show 或者 操作图片之前,必须要把图片打开(读取),也就是 Image.open()一下,这可能是 PIL 这个类型图片的特有操作
        label = self.label_dir  # 这个例子中,比较特殊,因为图片的 label 值,就是图片所在上一级的目录
        return img, label  # img 是每一张图片的名称,根据这个名称,就可以使用查看(直接img)、print、size等功能
        # label 是这个图片的标签,在当前这个类中,标签,就是只文件夹名称,因为我们就是这样定义的

    def __len__(self):
        return len(self.img_path)  # img_path,已经是一个列表了,len()就是在对这个列表进行一些操作


if __name__ == '__main__':
    root_dir = "dataset/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)
    train_dataset = ants_dataset + bees_dataset
    
# runfile('E:/pythonProject/learn_pytorch/read_data.py')
### 关于 PyTorch教程 #### 安装与导入 PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。对于初学者来说,了解如何安装并导入 PyTorch 至关重要[^1]。 为了使 PyTorch 能够运行在带有 NVIDIA GPU 的计算机上,并利用其硬件加速功能,建议安装 GPU 版本的 PyTorch。这需要确保系统已安装 CUDA 和 cuDNN 库[^3]。 如果仅需使用 CPU 进行运算,则可以选择安装 CPU 版本的 PyTorch。此过程相对简单得多,适合那些不具备合适 GPU 设备的学习者或开发者。 #### 功能特性概述 PyTorch 提供了一个易于使用的 API 接口,使得编写代码如同日常处理 Python 数据结构一般轻松自在;它能够很好地融入现有的 Python 数据科学工作流之中。此外,由于其设计上的相似性,熟悉 NumPy 的用户会发现过渡到 PyTorch 十分容易[^4]。 该框架还具备支持分布式训练的能力,这意味着可以在多个节点之间分配任务以提高效率。同时,官方也提供了多种工具来帮助用户更好地管理和部署经过训练后的模型至生产环境中去。 最后值得一提的是,围绕着 PyTorch 形成了一个庞大且活跃度极高的社区生态体系,其中包括但不限于各种第三方扩展包以及预构建模块,这些资源极大地促进了诸如计算机视觉、自然语言处理等领域内项目的快速迭代与发展。 #### 实际操作指南 下面给出一段简单的 Python 代码片段用于展示如何加载数据集: ```python import torch from torchvision import datasets, transforms transform = transforms.Compose([transforms.ToTensor()]) dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True) for images, labels in dataloader: pass # Do something with the data here. ``` 这段程序展示了如何通过 `torchvision.datasets` 下载 MNIST 手写数字识别数据库,并创建相应的 DataLoader 对象以便后续批量读取图像样本及其标签信息。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值