# CamVid数据集介绍及读取,附代码(PyTorch版本)
最近想在BiSeNet网络上测试一下CamVid数据集。由于是第一次接触CamVid数据集,本想参考一些别人的博客抄个近道,结果一查之后才发现……
所以我想把自己理解的跟大家分享一下。
我也会把自己跑通的BiSeNet代码分享到该博客上。
## 数据集简介
CamVid数据集的文件布局如下所示:

其中class_dict是将数据集中每个物体使用RGB三通道的颜色值进行分类的。

这里面的class_11说的是数据集中存在的,且进行分割的物体。(这个是我个人理解,如果有误,欢迎各位在评论区留言。)
这里说了只是简介,想了解更多,只需要自己下载,查看一下便知。
## 读取代码解析
上代码!!!!!
下面的代码保存为CamVid.py:
import os
import torch
import glob
import os
from torchvision import transforms
#import cv2
from PIL import Image
import pandas as pd
import numpy as np
#from imgaug import augmenters as iaa
#import imgaug as ia
from utils import get_label_info, one_hot_it, RandomCrop, reverse_one_hot, one_hot_it_v11, one_hot_it_v11_dice
import random
def augmentation():
    """Placeholder for spatial image augmentation (flip, affine, rotation, ...).

    Not implemented: the function takes no arguments, does nothing and
    returns ``None``. See https://github.com/aleju/imgaug for the kind of
    transformations this stub was meant to wrap.
    """
    return None
def augmentation_pixel():
    """Placeholder for pixel-intensity augmentation (Gaussian blur, multiply, ...).

    Not implemented: the function takes no arguments, does nothing and
    returns ``None``.
    """
    return None
class CamVid(torch.utils.data.Dataset):
def __init__(self, image_path, label_path, csv_path, scale, loss='dice', mode='train'):
    """Collect the image/label file lists and set up preprocessing.

    Args:
        image_path: Directory, or list of directories, searched for
            ``*.png`` input images.
        label_path: Directory, or list of directories, searched for
            ``*.png`` label images.
        csv_path: Path passed to ``get_label_info``; presumably the
            class_dict CSV that maps classes to RGB colours -- confirm
            against ``utils.get_label_info``.
        scale: Stored as ``self.image_size``. ``__getitem__`` indexes it
            with ``[0]`` and ``[1]``, so it is presumably a
            (height, width) pair -- confirm against the caller.
        loss: Loss identifier, stored unchanged in ``self.loss``.
        mode: Stored in ``self.mode``; ``'train'`` enables the random
            resize and crop performed in ``__getitem__``.
    """
    super().__init__()
    self.mode = mode

    # Accept either a single directory or a list of directories.
    if not isinstance(image_path, list):
        image_path = [image_path]
    self.image_list = []
    for image_path_ in image_path:
        self.image_list.extend(glob.glob(os.path.join(image_path_, '*.png')))
    # Images and labels are paired purely by position in their sorted
    # lists, so both lists must be sorted the same way.
    self.image_list.sort()

    if not isinstance(label_path, list):
        label_path = [label_path]
    self.label_list = []
    for label_path_ in label_path:
        self.label_list.extend(glob.glob(os.path.join(label_path_, '*.png')))
    self.label_list.sort()

    self.label_info = get_label_info(csv_path)

    # Tensor conversion followed by per-channel normalization; the
    # constants match the widely used ImageNet mean/std values.
    self.to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Target crop size, as passed in by the caller.
    self.image_size = scale
    # Candidate resize factors; __getitem__ picks one at random.
    self.scale = [0.5, 1, 1.25, 1.5, 1.75, 2]
    self.loss = loss
def __getitem__(self, index):
# load image and crop
seed = random.random()
img = Image.open(self.image_list[index])
img.show()
# random crop image
# =====================================
# w,h = img.size
# th, tw = self.scale
# i = random.randint(0, h - th)
# j = random.randint(0, w - tw)
# img = F.crop(img, i, j, th, tw)
# =====================================
# print(self.scale)
scale = random.choice(self.scale)
# print(scale)
scale = (int(self.image_size[0] * scale), int(self.image_size[1] * scale))
# print(scale)
# print(self.image_size)
# print(scale)
# randomly resize image and random crop
# =====================================
if self.mode == 'train':
img = transforms.Resize(scale, Image.BILINEAR)(img)
img = RandomCrop(self.image_size, seed, pad_if_needed=True)(img)
# =====================================
img = np.array(img)
# # load label
label = Image.open(self.label_list[index])
label.show()
# # crop the corresponding label
# # =====================================
# # label = F.crop(label, i, j, th, tw)
# # =====================================
#
# # randomly resize label and random crop
# # =====================================
if self.mode == 'train':
label =
CamVid数据集读取与处理

本文详细介绍了如何使用Pytorch读取并处理CamVid数据集,包括数据集布局、分类字典的使用、图像和标签的预处理方法,以及如何通过自定义的getitem()函数配合dataloader进行高效的数据读取。
最低0.47元/天 解锁文章
5948

被折叠的 条评论
为什么被折叠?



