一、detr介绍
DEtection TRansformer(DETR)是Facebook AI的研究者提出的Transformer的视觉版本,用于目标检测和全景分割。这是第一个将Transformer成功整合为检测pipeline中心构建块的目标检测框架。
1、代码地址Github:https://github.com/facebookresearch/detr
2、论文地址paper with code:
End-to-End Object Detection with Transformers
二、使用步骤
1、先将代码下载下来并在Pycharm中打开,创建一个虚拟环境,激活后点击terminal,输入
pip install -r requirements.txt
如果要安装cuda版本的torch和torchvision,可以在Pytorch官网搜索,我这边直接给出下载指令
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
这样,基本的环境就没什么问题了。
2、数据集:coco数据集
格式:
path/to/coco/
annotations/ # annotation json files
train2017/ # train images
val2017/ # val images
3、模型
4、训练模型
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --coco_path coco
模型结果:
训练十个epoch,每训练一个打印出参数,并保存在log.txt文件夹当中:
5、预测
预测脚本predict.py
内容如下:
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
torch.set_grad_enabled(False)
# COCO classes
CLASSES = [
'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse'