import os
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
class dataLoader(Dataset):
def __init__(self, path, listName, dataset='', data_transforms=None, loader=None):
    """Index an image list file into parallel lists of image paths and labels.

    Each non-blank line of ``listName`` must contain at least two
    whitespace-separated fields: an image path relative to ``path`` and an
    integer class label. Any further fields on a line are ignored.

    Args:
        path: Root directory prepended (via ``os.path.join``) to every image path.
        listName: Path of the text list file to read.
        dataset: Stored unchanged as ``self.dataset``; not used by the code
            visible here (presumably a split/dataset name -- confirm with callers).
        data_transforms: Stored as ``self.data_transforms``; presumably applied
            to each loaded image elsewhere in the class -- TODO confirm.
        loader: Callable mapping an image path to an image. When ``None``,
            ``self.default_loader`` is used.

    Raises:
        OSError: If ``listName`` cannot be opened.
        ValueError: If a non-blank line has fewer than two fields, or its
            label field is not an integer.
    """
    self.path = path
    self.listName = listName

    images = []
    labels = []
    # Read the list a single time inside `with` so the handle is closed
    # promptly instead of being left for the garbage collector.
    with open(self.listName) as listFile:
        for lineNo, line in enumerate(listFile, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. an extra newline at end of file.
                continue
            if len(fields) < 2:
                raise ValueError(
                    f"{self.listName}:{lineNo}: expected '<image> <label>', "
                    f"got {line.strip()!r}"
                )
            images.append(os.path.join(self.path, fields[0]))
            labels.append(int(fields[1]))
    self.images = images
    self.labels = labels

    self.data_transforms = data_transforms
    self.dataset = dataset
    # Explicit None check: any caller-supplied callable is used as-is.
    self.loader = loader if loader is not None else self.default_loader
def default_loader(self, imageName):
    """Load ``imageName`` with PIL and return it converted to RGB.

    Best-effort: if the file cannot be opened or decoded, a message naming
    the file is printed and ``None`` is returned instead of raising.

    Args:
        imageName: Filesystem path of the image to load.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode, or ``None`` on failure.
    """
    try:
        # Open the file ourselves inside `with`; convert() reads the pixel
        # data into a new image before the handle is closed.
        with open(imageName, 'rb') as imageFile:
            image = Image.open(imageFile)
            return image.convert('RGB')
    except (OSError, ValueError):
        # OSError covers missing/unreadable files and undecodable or
        # truncated images. Narrow catch so KeyboardInterrupt and
        # programming errors still propagate.
        print("Cannot read image", imageName)
        return None
def __len__(self):
PyTorch学习笔记(2)-使用自定义txt文件读取数据
最新推荐文章于 2025-05-10 21:25:53 发布