ImageFolder(通用类型)
使用ImageFolder读取数据的前提是:root目录下每个子文件夹名对应一个类名,每个子文件夹中存放该类别的图片,例如:
root/dog/xxx.png
root/dog/xxy.png
root/cat/123.png
root/cat/nsdf3.png
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
参数:
- root:在root指定的路径下寻找图片
- transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
- target_transform:对label(类别索引)的转换
- loader:给定图片路径、读取并返回图片的函数,默认后端下返回RGB格式的PIL Image
ImageFolder对象有三个常用的成员变量:
- self.classes:用一个list保存类名
- self.class_to_idx:类名到索引的映射(dict)
- self.imgs:保存(img_path, class_index)元组的list
以按 train/类别名/图片 组织的CIFAR-100数据为例:
from torchvision import transforms as T
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
# Build the dataset from ./train/<class_name>/<image files>.
dataset = ImageFolder(root='train')

# Other attributes worth inspecting:
#   dataset.classes       -> list of class names
#   dataset.class_to_idx  -> {class_name: class_index}
samples = dataset.imgs  # list of (image_path, class_index) tuples
print(samples)
参考:
- pytorch中文官网:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/#imagefolder
- 原文链接:https://blog.youkuaiyun.com/weixin_40123108/article/details/85099449
继承torch.utils.data.Dataset(通用类型)
基本流程:
- 根据数据集生成csv文件,每行格式为“文件路径 标签”。
- 继承torch.utils.data.Dataset,通过上述csv文件读取图片。
- 使用自定义的Dataset构建训练集和验证集。
1、process.py:生成train.csv和test.csv(每行为“文件路径 标签”)
import os
import pandas as pd
# Dataset roots, relative to the current working directory. Each root is
# scanned by get_train_pic, which treats every entry as a class sub-directory.
Train_DIR = "train"
Test_DIR = "test"
def get_train_pic(path) -> dict:
    """Map every file under ``path/<class_dir>/`` to an integer class label.

    Class sub-directories are visited in sorted name order and numbered
    0, 1, 2, ...  Sorting makes the numbering independent of the arbitrary
    order returned by ``os.listdir``, so running this on the train root and
    on the test root assigns the same label to the same class name (provided
    both roots contain the same class directories). Non-directory entries
    directly under ``path`` are ignored instead of crashing ``os.listdir``.

    Args:
        path: dataset root, relative to the current working directory or
            absolute, laid out as ``path/<class_name>/<files>``.

    Returns:
        dict mapping absolute file path -> integer class label.
    """
    base_dir = os.path.abspath('.')  # current absolute working directory
    root = os.path.join(base_dir, path)  # an absolute ``path`` wins here
    class_names = sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(root, name))
    )
    result = dict()
    for label, class_name in enumerate(class_names):
        real_path = os.path.join(root, class_name)
        print(real_path)
        for file_name in sorted(os.listdir(real_path)):
            result[os.path.join(real_path, file_name)] = label
    return result
train = get_train_pic(Train_DIR)
test = get_train_pic(Test_DIR)
data_train = pd.DataFrame.from_dict(train, orient='index')
data_test = pd.DataFrame.from_dict(test, orient='index')
# Write one "path label" pair per line: space separated, no header row.
# ImageNetData splits each line on whitespace and skips lines that do not
# yield exactly two fields. pandas' default CSV output ("path,label" rows
# plus a ",0" header) gives one token per line, so every sample would be
# skipped. Paths containing spaces are still not supported by that reader.
data_train.to_csv("train.csv", sep=" ", header=False)
data_test.to_csv("test.csv", sep=" ", header=False)
2、自定义Dataset:
import torch
import numpy as np
from torchvision import datasets, transforms
import torch.utils.data as data
from PIL import Image
import random
import os
import cv2
import random
def pil_loader(path):
    """Open the image at ``path`` with Pillow and return an RGB copy.

    The file is opened explicitly, instead of passing the path straight to
    ``Image.open``, so that the handle is closed deterministically and no
    ResourceWarning is emitted
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        rgb_image = picture.convert('RGB')
    return rgb_image
def default_loader(path):
    """Load ``path`` with :func:`pil_loader` unless the backend is accimage.

    Raises:
        RuntimeError: if torchvision's configured image backend is
            ``'accimage'``, which this loader does not support.
    """
    from torchvision import get_image_backend
    backend = get_image_backend()
    if backend != 'accimage':
        return pil_loader(path)
    # NOTE(review): the message is misleading; this branch fires because the
    # accimage backend is selected, not because the module is missing.
    raise RuntimeError('No Module named accimage')
class ImageNetData(data.Dataset):
    """Image-classification dataset driven by a "path label" list file.

    Each line of ``img_file`` is split on whitespace. Lines that do not yield
    exactly two fields (e.g. blank lines) are skipped. The first field is
    joined onto ``<img_root>/train`` when ``is_training`` is true, otherwise
    onto ``<img_root>/test``. The second field is parsed as an int class id.
    ``os.path.join`` discards the prefix when the first field is absolute, so
    absolute paths in the list file are used as-is.

    Args:
        img_root: dataset root directory.
        img_file: path of the UTF-8 encoded list file.
        is_training: read from the ``train`` sub-directory instead of ``test``.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the integer class id.
        loader: callable mapping an image path to an image object.
    """

    def __init__(self, img_root, img_file, is_training=False, transform=None,
                 target_transform=None, loader=default_loader):
        self.root = img_root
        split_dir = 'train' if is_training else 'test'
        self.imgs = []  # list of (image_path, class_id) tuples
        with open(img_file, 'r', encoding='utf-8') as fd:
            for _line in fd:
                infos = _line.strip().split()
                # Notice: anything other than exactly "path label" is ignored.
                if len(infos) != 2:
                    continue
                real_path = os.path.join(self.root, split_dir, infos[0])
                self.imgs.append((real_path, int(infos[-1])))
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, class_id)`` for sample ``index`` with transforms applied."""
        path, class_id = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            # Apply the caller-supplied label transform. The old behaviour
            # corresponds to target_transform=lambda y: torch.LongTensor([y]).
            class_id = self.target_transform(class_id)
        return img, class_id

    def __len__(self):
        """Return the number of samples parsed from the list file."""
        return len(self.imgs)
3、使用:
# NOTE(review): `args` (with data_root / train_file / val_file) and
# `data_transforms` (keyed by 'train' / 'val') are not defined in this
# snippet; they are assumed to come from the surrounding training script.
# With is_training=False the validation images are read from
# <data_root>/test.
train_datasets = ImageNetData(
    args.data_root,
    args.train_file,
    is_training=True,
    transform=data_transforms['train'],
)
val_datasets = ImageNetData(
    args.data_root,
    args.val_file,
    is_training=False,
    transform=data_transforms['val'],
)