pytorch 图像数据集管理

本文围绕Pytorch展开,介绍了数据集管理方法,使用Dataset管理训练和测试数据集,通过DataLoader容器取出批次数据训练模型。还说明了可继承Dataset类管理数据,以及图像分类常用类ImageFolder的使用,它能自动分类图片形成数据集,且可将数据集分为训练和测试集。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

1.数据集的管理说明

2.数据集Dataset类说明

3.图像分类常用的类 ImageFolder


1.数据集的管理说明

        pytorch使用Dataset来管理训练和测试数据集,前文说过 

torchvision.datasets.MNIST

        这些 torchvision.datasets里面的数据集都是继承Dataset而来,对Datasetd 管理使用DataLoader我们使用的的时候,只需要把Dataset类放在DataLoader这个容器里面,在训练的时候 for循环从DataLoader容器里面取出批次的数据,对模型进行训练。

2.数据集Dataset类说明

        我们可以继承Dataset类,对训练和测试数据进行管理,继承Dataset示例:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import cv2
#继承from torch.utils.data import Dataset
class CDataSet(Dataset):
    def __init__(self,path):
        self.path = path
        self.list = os.listdir(path)
        self.len = len(self.list)
        self.name = ['cloudy','rain','shine','sunrise']
        self.trans = transforms.ToTensor()
    def __len__(self):
        return self.len
    def __getitem__(self, item):
        self.imgpath = os.path.join(self.path,self.list[item])
        print(self.imgpath)
        img = cv2.imread(self.imgpath)
        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        img = cv2.resize(img,(100,100))
        img = self.trans(img)
        for i,n in enumerate(self.name):
            if n in self.imgpath:
                label = i+1
                break
        return img,label

ds = CDataSet(r'E:\test\pythonProject\dataset\cloudy')
dl = DataLoader(ds,batch_size=16,shuffle=True)
print(len(ds))
print(len(dl))
print(type(ds))
print(type(dl))
print(next(iter(dl)))


'''
D:\anaconda3\python.exe E:\test\pythonProject\test.py 
300
19
<class '__main__.CDataSet'>
<class 'torch.utils.data.dataloader.DataLoader'>
E:\test\pythonProject\dataset\cloudy\cloudy294.jpg
E:\test\pythonProject\dataset\cloudy\cloudy156.jpg
E:\test\pythonProject\dataset\cloudy\cloudy149.jpg
E:\test\pythonProject\dataset\cloudy\cloudy148.jpg
E:\test\pythonProject\dataset\cloudy\cloudy3.jpg
E:\test\pythonProject\dataset\cloudy\cloudy106.jpg
E:\test\pythonProject\dataset\cloudy\cloudy137.jpg
E:\test\pythonProject\dataset\cloudy\cloudy276.jpg
E:\test\pythonProject\dataset\cloudy\cloudy147.jpg
E:\test\pythonProject\dataset\cloudy\cloudy8.jpg
E:\test\pythonProject\dataset\cloudy\cloudy164.jpg
E:\test\pythonProject\dataset\cloudy\cloudy293.jpg
E:\test\pythonProject\dataset\cloudy\cloudy116.jpg
E:\test\pythonProject\dataset\cloudy\cloudy56.jpg
E:\test\pythonProject\dataset\cloudy\cloudy187.jpg
E:\test\pythonProject\dataset\cloudy\cloudy177.jpg
[tensor([[[[0.2235, 0.2471, 0.3569,  ..., 0.1490, 0.1373, 0.1373],
          [0.2902, 0.4039, 0.4078,  ..., 0.1529, 0.1373, 0.1294],
          [0.3294, 0.4941, 0.4000,  ..., 0.1529, 0.1333, 0.1137],
          ...,
          [0.0118, 0.0118, 0.0118,  ..., 0.0078, 0.0078, 0.0078],
          [0.0118, 0.0118, 0.011
世界地图矢量数据可以通过多种网站进行下载。以下是一些提供免费下载世界地图矢量数据的网站: 1. Open Street Map (https://www.openstreetmap.org/): 这个网站可以根据输入的经纬度或手动选定范围来导出目标区域的矢量图。导出的数据格式为osm格式,但只支持矩形范围的地图下载。 2. Geofabrik (http://download.geofabrik.de/): Geofabrik提供按洲际和国家快速下载全国范围的地图数据数据格式支持shape文件格式,包含多个独立图层,如道路、建筑、水域、交通、土地利用分类、自然景观等。数据每天更新一次。 3. bbbike (https://download.bbbike.org/osm/): bbbike提供全球主要的200多个城市的地图数据下载,也可以按照bbox进行下载。该网站还提供全球数据数据格式种类齐全,包括geojson、shp等。 4. GADM (https://gadm.org/index.html): GADM提供按国家或全球下载地图数据的服务。该网站提供多种格式的数据下载。 5. L7 AntV (https://l7.antv.antgroup.com/custom/tools/worldmap): L7 AntV是一个提供标准世界地图矢量数据免费下载的网站。支持多种数据格式下载,包括GeoJSON、KML、JSON、TopJSON、CSV和高清SVG格式等。可以下载中国省、市、县的矢量边界和世界各个国家的矢量边界数据。 以上这些网站都提供了世界地图矢量数据免费下载服务,你可以根据自己的需求选择合适的网站进行下载
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值