你使用了 PyTorch 自带的 `torchvision.datasets.CocoDetection` 来加载 COCO 格式数据集,这是一个非常简洁且推荐的方式。相比手动实现 `Dataset` 类,`CocoDetection` 已经封装好了图像与标注的读取逻辑,并直接返回 PIL 图像和 COCO 注解列表。
下面我将 **完整补充你的代码**,包括:
- 使用 `CocoDetection` 加载数据
- 正确的预处理与数据增强(支持目标检测)
- 自定义训练循环
- 模型微调配置(修改分类头以适配自定义类别数)
- DataLoader 构建注意事项
- 训练过程中的损失打印
---
## ✅ 完整可运行代码:基于 `CocoDetection` 的 Faster R-CNN 训练
```python
import torch
import torchvision
from torchvision.datasets import CocoDetection
from torchvision.transforms import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import DataLoader
import torch.optim as optim
from collections import defaultdict
import time
import os
# -------------------------------
# 1. 设置设备
# -------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
# -------------------------------
# 2. 自定义转换函数(支持图像+目标同时增强)
# -------------------------------
# 注意:ToTensor 会改变 bbox 等结构,所以我们需要一个包装函数
def collate_fn(batch):
return tuple(zip(*batch))
# 推荐使用 Albumentations 进行联合增强(但这里先用最简方式)
# 我们只做 ToTensor 并确保 target 不被自动转换
dataset = CocoDetection(
root="dataset/images/train",
annFile="coco_annotations.json",
transform=transforms.ToTensor() # 只对图像应用,CocoDetection 会保留 target 原始格式
)
data_loader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
num_workers=4,
collate_fn=collate_fn # 关键!因为每张图的目标数量不同
)
```
### 🔧 修改模型以适配自定义类别数
```python
def get_model(num_classes):
# 加载预训练 Faster R-CNN
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# 获取分类器输入特征数
in_features = model.roi_heads.box_predictor.cls_score.in_features
# 替换预测头
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
return model
```
### 🛠️ 转换 COCO 注解为模型所需格式(边界框 + label)
```python
def convert_coco_annotation(annotations):
"""将原始 COCO 注解转换为 Faster R-CNN 所需的 tensor 格式"""
boxes = []
labels = []
area = []
iscrowd = []
for ann in annotations:
# COCO bbox 格式是 [x, y, width, height]
x_min, y_min, w, h = ann['bbox']
boxes.append([x_min, y_min, x_min + w, y_min + h]) # 转为 [x_min, y_min, x_max, y_max]
labels.append(ann['category_id'])
area.append(ann['area'])
iscrowd.append(ann['iscrowd'])
return {
"boxes": torch.tensor(boxes, dtype=torch.float32),
"labels": torch.tensor(labels, dtype=torch.int64),
"area": torch.tensor(area, dtype=torch.float32),
"iscrowd": torch.tensor(iscrowd, dtype=torch.uint8)
}
```
### 🚀 训练函数
```python
def train_one_epoch(model, optimizer, data_loader, device, epoch):
model.train()
loss_dict_sum = defaultdict(float)
total_batches = len(data_loader)
start_time = time.time()
for i, (images, targets) in enumerate(data_loader):
# 将图像转到设备
images = [img.to(device) for img in images]
# 转换每个 target 为模型期望的格式
converted_targets = []
for target in targets:
converted_targets.append({k: v.to(device) for k, v in convert_coco_annotation(target).items()})
# 前向传播
loss_dict = model(images, converted_targets)
losses = sum(loss for loss in loss_dict.values())
# 反向传播
optimizer.zero_grad()
losses.backward()
optimizer.step()
# 统计损失
for key, value in loss_dict.items():
loss_dict_sum[key] += value.item()
if i % 10 == 0:
print(f"Epoch [{epoch}], Batch [{i}/{total_batches}], Loss: {losses.item():.4f}")
avg_loss = {k: v / len(data_loader) for k, v in loss_dict_sum.items()}
print(f"Epoch {epoch} completed in {time.time() - start_time:.2f}s, Avg Loss: {sum(avg_loss.values()):.4f}")
return avg_loss
```
### 🏁 主函数
```python
def main():
num_classes = 10 # 示例:你的数据集中有 9 个物体类别 + 背景(注意:COCO 类别 ID 不一定连续)
# 获取模型
model = get_model(num_classes)
model.to(device)
# 优化器
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# 学习率调度器(可选)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 10
for epoch in range(num_epochs):
train_one_epoch(model, optimizer, data_loader, device, epoch)
lr_scheduler.step() # 更新学习率
# 保存模型
torch.save(model.state_dict(), "faster_rcnn_custom.pth")
print("Training finished. Model saved.")
if __name__ == "__main__":
main()
```
---
## 🔍 关键点解释
| 技术点 | 解释 |
|-------|------|
| `collate_fn=collate_fn` | 因为每张图像的对象数量不同,PyTorch 默认 `DataLoader` 无法堆叠 bbox 张量,必须用 `zip(*batch)` 分离图像和目标 |
| `convert_coco_annotation()` | 将字典形式的 COCO 注解(如 `{'bbox': [...], 'category_id': 1}`)转换为模型所需的 `{"boxes": ..., "labels": ...}` Tensor 格式 |
| `model(images, targets)` | `targets` 必须包含 `boxes`, `labels` 等字段,且在 GPU 上 |
| `FastRCNNPredictor` 替换 | 允许模型输出你自己的类别数(例如不是 91 类) |
---
## ✅ 提示与建议
1. **类别 ID 映射问题**:
- 如果你的 `category_id` 是非连续的(比如 `[3, 5, 8]`),你需要建立映射表将其压缩为 `[1, 2, 3]`
- 或者使用 `categories = coco.loadCats(coco.getCatIds())` 查看实际类别名
2. **添加数据增强(强烈建议)**
```bash
pip install albumentations
```
然后使用 [Albumentations](https://albumentations.ai/docs/examples/pytorch_detection/) 实现 `bbox` 同时增强。
3. **验证集构建方法相同**
```python
val_dataset = CocoDetection(root="val/images", annFile="val.json", transform=ToTensor())
```
4. **性能监控**
- 使用 `TensorBoard` 或 `tqdm` + 日志记录损失变化
---