PyTorch/torchvision.datasets自带常用数据集【总结】

本文介绍了PyTorch框架中可用的各种数据集,包括MNIST、COCO、LSUN、ImageNet、CIFAR、STL10、SVHN和PhotoTour等,涵盖了从手写数字到复杂图像识别的不同领域。同时,文章提供了数据集的调用方式和一些常见问题的解决方案。

一、PyTorch环境

@PyTorch 1.0
@安装torchvision及报错的解决办法如下:
@下载命令1:pip3 install torchvision #可能会报错
@下载命令2:pip install --no-deps torchvision
@Linux下,可能需要sudo

二、PyTorch自带常用数据集列表

  • MNIST
    #一个手写数字数据集集,提供了60000+训练用例和10000个测试用例

    The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples, and a test set of 10,000 examples.

  • COCO
    #通常认为COCO是更倾向于图像分割的数据集

    COCO-Text is a new large scale dataset for text detection and recognition in natural images.
    Version 1.3 of the dataset is out!
    63,686 images, 145,859 text instances, 3 fine-grained text attributes.
    This dataset is based on the MSCOCO dataset.
    •Text localizations as bounding boxes
    •Text transcriptions for legible text
    •Multiple text instances per image
    •More than 63,000 images
    •More than 145,000 text instances
    •Text instances categorized into machine printed and handwritten text
    •Text instances categorized into legible and illegilbe text
    •Text instances categorized into English script and non-English script

  • Captions
    #望文生义,主要是标题测试用例,通常算在COCO中

  • Detection

  • #感知数据集,通常也包含在COCO中

  • LSUN
    #场景感知数据集(感觉很酷

    While there has been remarkable progress in the performance of visual recognition algorithms, the state-of-the-art models tend to be exceptionally data-hungry. Large labeled training datasets, expensive and tedious to produce, are required to optimize millions of parameters in deep network models.
    详细信息点击:

你使用了 PyTorch 自带的 `torchvision.datasets.CocoDetection` 来加载 COCO 格式数据集,这是一个非常简洁且推荐的方式。相比手动实现 `Dataset` 类,`CocoDetection` 已经封装好了图像与标注的读取逻辑,并直接返回 PIL 图像和 COCO 注解列表。 下面我将 **完整补充你的代码**,包括: - 使用 `CocoDetection` 加载数据 - 正确的预处理与数据增强(支持目标检测) - 自定义训练循环 - 模型微调配置(修改分类头以适配自定义类别数) - DataLoader 构建注意事项 - 训练过程中的损失打印 --- ## ✅ 完整可运行代码:基于 `CocoDetection` 的 Faster R-CNN 训练 ```python import torch import torchvision from torchvision.datasets import CocoDetection from torchvision.transforms import transforms from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torch.utils.data import DataLoader import torch.optim as optim from collections import defaultdict import time import os # ------------------------------- # 1. 设置设备 # ------------------------------- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') print(f"Using device: {device}") # ------------------------------- # 2. 自定义转换函数(支持图像+目标同时增强) # ------------------------------- # 注意:ToTensor 会改变 bbox 等结构,所以我们需要一个包装函数 def collate_fn(batch): return tuple(zip(*batch)) # 推荐使用 Albumentations 进行联合增强(但这里先用最简方式) # 我们只做 ToTensor 并确保 target 不被自动转换 dataset = CocoDetection( root="dataset/images/train", annFile="coco_annotations.json", transform=transforms.ToTensor() # 只对图像应用,CocoDetection 会保留 target 原始格式 ) data_loader = DataLoader( dataset, batch_size=2, shuffle=True, num_workers=4, collate_fn=collate_fn # 关键!因为每张图的目标数量不同 ) ``` ### 🔧 修改模型以适配自定义类别数 ```python def get_model(num_classes): # 加载预训练 Faster R-CNN model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) # 获取分类器输入特征数 in_features = model.roi_heads.box_predictor.cls_score.in_features # 替换预测头 model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) return model ``` ### 🛠️ 转换 COCO 注解为模型所需格式(边界框 + label) ```python def convert_coco_annotation(annotations): """将原始 COCO 注解转换为 Faster R-CNN 所需的 tensor 格式""" boxes = [] labels = [] area = [] iscrowd = [] for ann in annotations: # COCO bbox 格式是 [x, y, width, height] x_min, y_min, w, h = ann['bbox'] boxes.append([x_min, y_min, x_min + w, y_min + h]) # 转为 [x_min, y_min, x_max, y_max] labels.append(ann['category_id']) area.append(ann['area']) iscrowd.append(ann['iscrowd']) return { "boxes": torch.tensor(boxes, dtype=torch.float32), "labels": torch.tensor(labels, dtype=torch.int64), "area": torch.tensor(area, dtype=torch.float32), "iscrowd": torch.tensor(iscrowd, dtype=torch.uint8) } ``` ### 🚀 训练函数 ```python def train_one_epoch(model, optimizer, data_loader, device, epoch): model.train() loss_dict_sum = defaultdict(float) total_batches = len(data_loader) start_time = time.time() for i, (images, targets) in enumerate(data_loader): # 将图像转到设备 images = [img.to(device) for img in images] # 转换每个 target 为模型期望的格式 converted_targets = [] for target in targets: converted_targets.append({k: v.to(device) for k, v in convert_coco_annotation(target).items()}) # 前向传播 loss_dict = model(images, converted_targets) losses = sum(loss for loss in loss_dict.values()) # 反向传播 optimizer.zero_grad() losses.backward() optimizer.step() # 统计损失 for key, value in loss_dict.items(): loss_dict_sum[key] += value.item() if i % 10 == 0: print(f"Epoch [{epoch}], Batch [{i}/{total_batches}], Loss: {losses.item():.4f}") avg_loss = {k: v / len(data_loader) for k, v in loss_dict_sum.items()} print(f"Epoch {epoch} completed in {time.time() - start_time:.2f}s, Avg Loss: {sum(avg_loss.values()):.4f}") return avg_loss ``` ### 🏁 主函数 ```python def main(): num_classes = 10 # 示例:你的数据集中有 9 个物体类别 + 背景(注意:COCO 类别 ID 不一定连续) # 获取模型 model = get_model(num_classes) model.to(device) # 优化器 params = [p for p in model.parameters() if p.requires_grad] optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005) # 学习率调度器(可选) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1) num_epochs = 10 for epoch in range(num_epochs): train_one_epoch(model, optimizer, data_loader, device, epoch) lr_scheduler.step() # 更新学习率 # 保存模型 torch.save(model.state_dict(), "faster_rcnn_custom.pth") print("Training finished. Model saved.") if __name__ == "__main__": main() ``` --- ## 🔍 关键点解释 | 技术点 | 解释 | |-------|------| | `collate_fn=collate_fn` | 因为每张图像的对象数量不同,PyTorch 默认 `DataLoader` 无法堆叠 bbox 张量,必须用 `zip(*batch)` 分离图像和目标 | | `convert_coco_annotation()` | 将字典形式的 COCO 注解(如 `{'bbox': [...], 'category_id': 1}`)转换为模型所需的 `{"boxes": ..., "labels": ...}` Tensor 格式 | | `model(images, targets)` | `targets` 必须包含 `boxes`, `labels` 等字段,且在 GPU 上 | | `FastRCNNPredictor` 替换 | 允许模型输出你自己的类别数(例如不是 91 类) | --- ## ✅ 提示与建议 1. **类别 ID 映射问题**: - 如果你的 `category_id` 是非连续的(比如 `[3, 5, 8]`),你需要建立映射表将其压缩为 `[1, 2, 3]` - 或者使用 `categories = coco.loadCats(coco.getCatIds())` 查看实际类别名 2. **添加数据增强(强烈建议)** ```bash pip install albumentations ``` 然后使用 [Albumentations](https://albumentations.ai/docs/examples/pytorch_detection/) 实现 `bbox` 同时增强。 3. **验证集构建方法相同** ```python val_dataset = CocoDetection(root="val/images", annFile="val.json", transform=ToTensor()) ``` 4. **性能监控** - 使用 `TensorBoard` 或 `tqdm` + 日志记录损失变化 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值