Environment
- OS: macOS Mojave
- Python version: 3.7
- PyTorch version: 1.4.0
- IDE: PyCharm
文章目录
0. 写在前面
本文记录一下使用 PyTorch 读取图像数据。数据按照特定的目录结构放好之后,需要构建 Dataset 和 DataLoader 对数据进行读取。
Dataset 定义了读取数据的位置(data_dir)和方式(__getitem__),而 DataLoader 给出 indices 决定让 Dataset 读取哪些数据。
1. 构造 Dataset
以划分好的 TinyMind人民币面值识别 任务的训练集为例,目录结构如下
├── rmbdata
│ ├── categorise.py
│ ├── split.py
│ ├── test
│ │ ├── images.jpg
│ ├── train
│ │ ├── images.jpg
│ ├── train
│ │ ├── images.jpg
│ ├── train_face_value_label.csv
│ ├── val
│ │ ├── images.jpg
1.1 法一:定义 Dataset 子类
在 data 模块中定义一个 torch.utils.data.Dataset 的子类,对 __getitem__ 和 __len__ 方法进行重写,在训练或评估中导入使用
├── data.py
└── rmbdata
import os
from PIL import Image
from torch.utils.data import Dataset
class RMBFaceValueDataset(Dataset):
"""Dataset for classifying RMB by the face value"""
def __init__(self, data_dir, transform=None, class_to_idx=None):
"""
Params:
data_dir: str
directory of data
transform: torchvision.transform
data transform approaches
idx_to_class: dict
map class name to a specified number
"""
self.data_dir = data_dir
self.transform = transform
self.class_to_idx = class_to_idx
self.data_info = self._get_img_info()
def __getitem__(self, index):
"""
Receive an index and return an example with its label
Returns:
image: torch.Tensor or PIL.Image
image data
label: int
label for this example
"""
image_path, label = self.data_info[index]
image = Image.open(image_path).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.data_info)
def _get_img_info(self):
data_info = []
for root, dirs, _ in os

本文介绍了如何在PyTorch中读取图像数据,包括构造Dataset的两种方法——定义Dataset子类和使用ImageFolder类,以及如何构造DataLoader,特别是在处理类别不平衡问题时的采样器应用。
最低0.47元/天 解锁文章
7万+

被折叠的 条评论
为什么被折叠?



