【KnowledgeBase】基于Pytorch建立一个自定义的目标检测DataLoader


前言

代码和文件夹免费公开,学习自取。链接!链接!链接!

本文介绍如何通过torch建立一个自己的目标检测数据集DataLoader。以WIDERFACE的部分图片与YOLO格式标注为例。本文分为以下4步介绍建立DataLoader的整体思路,具体还是要根据自己的数据集File格式进行调整:

  1. 数据集File格式介绍
  2. 代码整体思路及展示
  3. 代码分块介绍
  4. 代码测试

一、数据集File格式介绍

我们使用了4张WIDERFACE中的图片以及YOLO格式的标签来进行说明,整体的数据结构如下图,其中用来测试使用的代码文件DIY_DataLoader.ipynb也在同一目录下。
在这里插入图片描述

  1. imgaes中存放.jpg图片;
    在这里插入图片描述

  2. labels中存放.txt的YOLO格式标注文件;
    在这里插入图片描述
    在这里插入图片描述

  3. DIY_DataLoader.ipynb是测试用的代码文件;

  4. train.txt中罗列了图片的路径。
    在这里插入图片描述


二、代码整体思路及展示

2.1 代码整体思路

自己的DIY的DataLoader需要重写其中的一些方法,主要包括:__int____len____getitem__

  • __int__中保存一些数据集相关信息,最终为了得到:每一张图片路径、每一个标注路径、对图片进行的transform;
  • __len__为了得到一共有多少张图片数量;
  • __getitem__为了得到其中某一张图片的[image_array, gt_bbox]

2.2 代码整体展示

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class WIDERFACE(Dataset):
    def __init__(self, root_dir, image_file, ann_file, ann_txt, transform=None):
        self.root_dir = root_dir        # Root file
        self.image_file = image_file    # Image file
        self.ann_file = ann_file        # Annotations file

        self.imagenames = self.load_imgnames(ann_txt)

        # Load imgs/annos file
        self.imgs = [f'{x}.jpg' for x in [os.path.join(root_dir, image_file, image) for image in self.imagenames]]
        self.annos = [f'{x}.txt' for x in [os.path.join(root_dir, ann_file, image) for image in self.imagenames]]

        self.transform = transform
    
    def __len__(self):
        return len(self.imagenames)
    
    def __getitem__(self, idx):
        image = np.array(Image.open(self.imgs[idx]).getdata())
        with open(self.annos[idx]) as f:
            gt_bbox = [x.strip('\n').split('/')[-1] for x in f.readlines()] # x, y, width, height
        sample = {'img': image, 'gt_bbox': gt_bbox}
        if self.transform:
            sample = self.transform(sample)
        return sample
    
    def load_imgnames(self, ann_txt):
        with open(ann_txt) as f:
            samples = [x.strip('\n').split('/')[-1] for x in f.readlines()]
            names = [x.split('.')[0] for x in samples]
        return names

三、代码分块介绍

这里将一块块地详细介绍下类中每一个方法的内容。

3.1 def load_imgnames

这块代码最终为了读取下每一张图片的名称,在我们的文件夹中,它的输入为train.txt

	def load_imgnames(self, ann_txt):
        with open(self, ann_txt) as f:
            samples = [x.strip('\n').split('/')[-1] for x in f.readlines()]
            names = [x.split('.')[0] for x in samples]
        return names

简单测试一下,就是
在这里插入图片描述

3.2 def _init_

这一块主要是保存并告诉一下DataLoader,图片文件的具体路径、图片标注框的具体路径、用了什么transform方法。

	def __init__(self, root_dir, image_file, ann_file, ann_txt, transform=None):
        self.root_dir = root_dir        # Root file         './'
        self.image_file = image_file    # Image file        'images/'
        self.ann_file = ann_file        # Annotations file  'labels/'

        self.imagenames = self.load_imgnames(ann_txt)   # 得到了每张图片的名称

        # 基于self.imagenames,得到每张图片的 imgs/annos 具体的路径
        self.imgs = [f'{x}.jpg' for x in [os.path.join(root_dir, image_file, image) for image in self.imagenames]]
        self.annos = [f'{x}.txt' for x in [os.path.join(root_dir, ann_file, image) for image in self.imagenames]]

        self.transform = transform

3.3 def _len_

self.imagenames是一个保存了所有图片名称的List,故使用len()方法可以知道一共有多少张图片。当然self.imagenames也可以替换成self.imgs或者self.annos,效果是一样的。

	def __len__(self):
        return len(self.imagenames)

3.4 def _getitem_

    def __getitem__(self, idx):
        # 根据图片路径打开图片并转化成np.array格式
        image = np.array(Image.open(self.imgs[idx]).getdata())
        # 保存图片对应的gt_bbox[x, y, width, height]
        with open(self.annos[idx]) as f:
            gt_bbox = [x.strip('\n').split('/')[-1] for x in f.readlines()]
        # 使用dict对一张图片的信息进行包装
        sample = {'img': image, 'gt_bbox': gt_bbox}
        if self.transform:
            sample = self.transform(sample)
        return sample

四、代码测试

我们使用这个由4张图片组成的数据集进行一下DIY_WIDERFACE这个DataLoader的代码测试。

root_file = './'
image_file = 'images/'
ann_file = 'labels/'
ann_txt = './train.txt'

test = DIY_WIDERFACE(root_file, image_file, ann_file, ann_txt)
  1. __init__方法中储藏的一些信息展示,如下:

在这里插入图片描述

  1. __len__方法表示的图片数量,如下:

在这里插入图片描述

  1. __getitem__方法展示某一张图片的信息,包括图片的数组信息、gt_bbox,如下:

在这里插入图片描述


总结

本文就简单地带大家理解下DataLoader的构造思路。
欢迎批评指正。

<think>嗯,用户想了解如何使用PyTorch来实现AI,特别是要求编程和推理能力强的。首先,我需要明确他们的需求。他们可能已经对深度学习有一定的了解,但需要具体的指导和示例。可能他们想开发一个需要复杂模型和高效计算的AI应用,比如自然语言处理或者计算机视觉项目。 接下来,我应该考虑PyTorch的优势在哪里。PyTorch的动态计算图对于模型调试和灵活性很有帮助,这对提高编程效率很重要。然后,需要强调PyTorch的生态系统,比如TorchVision、TorchText这些库,可以加速开发过程。 用户提到“推理能力强”,这可能指的是模型在推理时的性能,包括速度和准确性。这里需要介绍如何优化模型,比如使用TorchScript进行序列化,或者使用ONNX导出模型到其他平台。另外,量化(Quantization)和半精度训练(FP16)也是提升推理速度的有效方法。 然后,编程能力方面,可能需要给出一些代码示例,展示PyTorch的基础模块,比如定义模型、训练循环、数据加载等。同时,结合具体的案例,如图像分类或文本生成,可以让用户更直观地理解。 还要考虑用户可能的背景。如果他们已经有PyTorch基础,可能需要更深入的内容,比如自定义层、混合精度训练,或者分布式训练。但如果他们是初学者,可能需要从安装和基本概念讲起。不过用户提到“编程、推理能力强”,可能更倾向于中高级内容。 另外,用户可能需要了解如何利用PyTorch的高级特性,如自动微分、自定义损失函数,或者集成最新的研究模型,比如Transformer架构。这些都能提升模型的推理能力。 最后,应该总结PyTorch的资源,比如官方文档、社区支持,以及学习路径,帮助用户进一步深入学习和解决问题。</think>PyTorch一个基于 Python 的深度学习框架,以其动态计算图、灵活性和高效的 GPU 加速能力著称,广泛应用于学术研究和工业界。以下是使用 PyTorch 实现 AI 的关键方向,结合编程和推理能力的实践指南: --- ### **1. PyTorch 核心优势** - **动态计算图(Dynamic Computation Graph)** 通过 `autograd` 实现自动微分,支持动态调整计算流程(如循环、条件分支),适合复杂模型(如 NLP 中的 Transformer)。 - **GPU 加速与并行化** 使用 `torch.cuda` 管理 GPU 计算,结合 `DataParallel` 或 `DistributedDataParallel` 实现多卡训练。 - **丰富的生态系统** TorchVision(CV)、TorchText(NLP)、TorchAudio(语音)等工具库加速开发。 --- ### **2. 提升编程能力的实践技巧** #### **(1) 模型定义与模块化** ```python import torch import torch.nn as nn class CustomModel(nn.Module): def __init__(self): super().__init__() self.conv = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2) ) self.fc = nn.Linear(64*13*13, 10) # 根据输入尺寸调整 def forward(self, x): x = self.conv(x) x = x.view(x.size(0), -1) return self.fc(x) ``` #### **(2) 高效数据加载与预处理** ```python from torchvision import transforms, datasets from torch.utils.data import DataLoader transform = transforms.Compose([ transforms.Resize(256), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) dataset = datasets.ImageFolder("data/train", transform=transform) dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) ``` #### **(3) 训练循环优化** ```python model = CustomModel().cuda() criterion = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) for epoch in range(100): for inputs, labels in dataloader: inputs, labels = inputs.cuda(), labels.cuda() outputs = model(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) # 梯度裁剪 optimizer.step() ``` --- ### **3. 提升推理性能的关键技术** #### **(1) 模型轻量化** - **量化(Quantization)**:降低数值精度(FP32 → INT8) ```python quantized_model = torch.quantization.quantize_dynamic( model, {nn.Linear}, dtype=torch.qint8 ) ``` - **知识蒸馏(Knowledge Distillation)**:用大模型指导小模型训练。 #### **(2) 部署优化** - **TorchScript**:将模型转换为静态图,支持跨平台部署 ```python scripted_model = torch.jit.script(model) scripted_model.save("model.pt") ``` - **ONNX 导出**:与其他框架(如 TensorFlow)互操作 ```python torch.onnx.export(model, dummy_input, "model.onnx") ``` #### **(3) 高性能推理技巧** - **半精度(FP16)**:使用 NVIDIA Apex 或 PyTorch 原生 AMP(Automatic Mixed Precision) ```python scaler = torch.cuda.amp.GradScaler() with torch.amp.autocast(): outputs = model(inputs) ``` - **TensorRT 集成**:针对 NVIDIA GPU 优化推理速度。 --- ### **4. 复杂推理任务的案例** #### **(1) 自然语言处理(NLP)** ```python from transformers import AutoModel, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") model = AutoModel.from_pretrained("bert-base-uncased").cuda() inputs = tokenizer("Hello, PyTorch!", return_tensors="pt").to("cuda") outputs = model(**inputs) # 输出词向量 ``` #### **(2) 生成对抗网络(GAN)** ```python # 生成器 generator = nn.Sequential( nn.Linear(100, 256), nn.LeakyReLU(0.2), nn.Linear(256, 784), nn.Tanh() ) # 判别器 discriminator = nn.Sequential( nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid() ) # Wasserstein GAN 损失函数 def wgan_loss(real_pred, fake_pred): return torch.mean(real_pred) - torch.mean(fake_pred) ``` --- ### **5. 学习资源与工具** - **官方教程**:[PyTorch Tutorials](https://pytorch.org/tutorials/) - **高级库**: - Hugging Face Transformers(NLP) - Detectron2(目标检测) - PyTorch Lightning(训练流程抽象) - **调试工具**: - `torchviz` 可视化计算图 - `torch.profiler` 性能分析 --- 通过结合 PyTorch 的灵活性与上述优化技术,可以高效实现从研究到部署的完整 AI 流水线。建议从简单模型(如 MNIST 分类)开始,逐步深入复杂任务(如目标检测、文本生成),同时关注模型压缩和推理加速技术。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Prymce-Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值