from torch.utils.data import Dataset
import cv2
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(root_dir, label_dir)
self.image_dir_list = os.listdir(self.path)
def __getitem__(self, item):
image_dir = os.path.join(self.path, item)
image = cv2.imread(image_dir)
label = self.label_dir
return image, label
def __len__(self):
return len(self.image_dir_list)
if __name__ == '__main__':
mydata = MyData(root_dir="hymenoptera_data/train", label_dir="ants")
image_dir_list = mydata.image_dir_list
for index in image_dir_list:
cv2.imshow(mydata.__getitem__(index)[1],mydata.__getitem__(index)[0])
cv2.waitKey(1000)
cv2.destroyAllWindows()
pytorch创建数据集
最新推荐文章于 2024-11-22 21:17:48 发布
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
1835

被折叠的 条评论
为什么被折叠?



