# -*- coding: utf-8 -*-
'''
@Time : 2020/5/20 22:28
@Author : HHNa
@FileName: dataset.py
@Software: PyCharm
'''
import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 原始SVHN中类别10为数字0
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [
使用 pip 安装 PyTorch 1.2 与 torchvision 0.3 时报错 “ImportError: DLL load failed”：原因是两者版本不匹配！
最新推荐文章于 2025-02-19 22:43:04 发布