BoxCars 数据集以 .pkl 文件的形式存放，因此首先需要定义一个读取 .pkl 文件的函数。
import pickle
def load_cache(path, encoding="latin-1", fix_imports=True):
    """Load and return the object stored in the pickle file at ``path``.

    The defaults target pickles written by Python 2: ``encoding="latin-1"``
    decodes Python 2 8-bit strings without errors, and ``fix_imports=True``
    maps Python 2 module names (e.g. ``__builtin__``) to their Python 3 names.

    Args:
        path: Filesystem path of the pickle file.
        encoding: Codec used to decode 8-bit string instances pickled by
            Python 2; forwarded to ``pickle.load``.
        fix_imports: Whether to translate Python 2 module/class names while
            loading; forwarded to ``pickle.load``.

    Returns:
        The unpickled object.

    Warning:
        Unpickling can execute arbitrary code. Only load files from a trusted
        source (e.g. the official dataset archive).
    """
    with open(path, "rb") as f:
        # Forward both compatibility options instead of hard-coding them.
        return pickle.load(f, encoding=encoding, fix_imports=fix_imports)
下面是自定义的数据集类（继承自 torch.utils.data.Dataset，可交给 DataLoader 使用）：
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision import models,transforms
import os
import time
import pickle
from PIL import Image
def default_loader(path, bb2d):
    """Open the image at ``path``, crop it to ``bb2d`` and return it as RGB.

    Args:
        path: Path of the image file.
        bb2d: 2-D bounding box, indexed as ``bb2d[0]`` .. ``bb2d[3]``. The crop
            box passed to PIL is
            ``(bb2d[0], bb2d[1], bb2d[0] + bb2d[3], bb2d[1] + bb2d[2])``.

    Returns:
        A new RGB ``PIL.Image.Image`` holding the cropped region, or ``None``
        if the file cannot be opened or decoded. A message is printed in
        that case.
    """
    # NOTE(review): the right edge adds bb2d[3] and the bottom edge adds
    # bb2d[2]. If the 2DBB annotation is stored as (x, y, w, h), these two
    # indices are swapped. Confirm against the BoxCars annotation format.
    box = (bb2d[0], bb2d[1], bb2d[0] + bb2d[3], bb2d[1] + bb2d[2])
    try:
        # `with` closes the underlying file handle. crop() and convert()
        # return NEW images, so their results must be returned rather than
        # discarded.
        with Image.open(path) as img:
            return img.crop(box).convert('RGB')
    except OSError:
        # Missing files and unreadable/truncated images raise OSError
        # (PIL's UnidentifiedImageError is an OSError subclass).
        print("Cannot read image:{}".format(path))
        return None
class customData(Dataset):
def __init__(self, path, part, mode, dataset='', data_transforms=None, loader=default_loader):
    """Index every image instance of the vehicles in one classification split.

    Populates three parallel lists with one entry per image instance:
    ``self.img_name`` (image file paths), ``self.img_label`` (class ids)
    and ``self.bb2d`` (the instance's ``'2DBB'`` bounding box).

    Args:
        path: Dataset root directory. It must contain ``dataset.pkl``,
            ``classification_splits.pkl`` and an ``images`` folder. A
            trailing path separator is optional.
        part: First-level key into the loaded classification splits.
        mode: Second-level key into ``classification_splits[part]``. The
            selected sequence holds entries whose ``[0]`` is a vehicle id and
            ``[1]`` is a class id.
        dataset: Stored unchanged on ``self.dataset``.
        data_transforms: Stored on ``self.data_transforms``.
        loader: Callable stored on ``self.loader``; defaults to
            ``default_loader``.
    """
    self.img_name = []
    self.img_label = []
    self.bb2d = []
    data = load_cache(os.path.join(path, 'dataset.pkl'))
    classification = load_cache(os.path.join(path, 'classification_splits.pkl'))
    for entry in classification[part][mode]:
        vehicle_id, class_id = entry[0], entry[1]
        # A vehicle can have several image instances. Each becomes its own
        # sample and inherits the vehicle's class id.
        for instance in data['samples'][vehicle_id]['instances']:
            self.img_name.append(os.path.join(path, 'images', instance['path']))
            self.img_label.append(class_id)
            self.bb2d.append(instance['2DBB'])
    self.data_transforms = data_transforms
    self.dataset = dataset
    self.loader = loader
def __len__(self):
    """Return how many image instances this dataset has indexed."""
    names = self.img_name
    return len(names)
def __getitem__(self, item):
img_name = self.img_name[item]
label = self.img_label[item]
bb2d = self.bb2d[item]
im