pytorch dataloader 自定义数据读取,resnet-50在boxcars数据集上

本文介绍了如何在PyTorch中自定义Dataloader来读取boxcars数据集,该数据集存储为.pkl文件。首先,我们创建一个读取.pkl文件的函数,接着定义Dataloader类。然后,实现训练(train)函数,最后进行测试(test)部分。使用ResNet-50模型进行图像识别。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

boxcars数据集放在pkl文件中,首先需要读取 .pkl文件,定义一个读取函数。

import pickle

def load_cache(path, encoding="latin-1", fix_imports=True):
    """
    encoding latin-1 is default for Python2 compatibility
    """
    with open(path, "rb") as f:
        return pickle.load(f, encoding=encoding, fix_imports=True)

下面放dataloader类

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable

from torch.utils.data import Dataset
from torchvision import models,transforms
import os
import time
import pickle

from PIL import Image


def default_loader(path, bb2d):
    try:
        img = Image.open(path)
        img.crop((bb2d[0],bb2d[1],bb2d[0]+bb2d[3],bb2d[1]+bb2d[2]))
        # 这里用了2d的bounding box对样本图片进行裁剪
        img.convert('RGB')
        return img
    except:
        print("Cannot read image:{}".format(path))

class customData(Dataset):
    def __init__(self,path,part,mode,dataset='',data_transforms=None,loader=default_loader):
# 主要就是将样本图片的路径和对应的标签分别放到self.img_name和self.img_label这两个list中
        self.img_name = []
        self.img_label = []
        self.bb2d = []
        data = load_cache(path + 'dataset.pkl')
        classification = load_cache(path + 'classification_splits.pkl')
        list = classification[part][mode]
        for i in range(len(list)):
            vehicle_id = list[i][0]
            class_id = list[i][1]
            instances = data['samples'][vehicle_id]['instances']
            for j in range(len(instances)):
                image = path + 'images/' + instances[j]['path']
                bb = instances[j]['2DBB']
                self.bb2d.append(bb)
                self.img_name.append(image)
                self.img_label.append(class_id)

        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader
       
    def __len__(self):  
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        bb2d = self.bb2d[item]
        im
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值