Swin Transformer Faster R-CNN 目标检测
1 环境
- 如果之前已经创建了 Swin Transformer Object Detection 项目所需的环境的话,可以直接使用,但是会对后面再训练Swin Transformer Object Detection 造成影响(因为mmdetection工程需要对mmdet的版本进行更改才能使用),所以建议再创建一个新的环境给mmdetection使用,或者直接clone一份之前的环境(推荐)。
- 克隆环境的方式为:
conda create -n conda-env2 --clone conda-env1
conda-env2
为新创建的环境(从其他环境clone来的)
conda-env1
为之前已经有的环境
注:克隆环境需要一段时间,请耐心等待。这样后面我们mmdetection的工程所使用的环境就是新clone的这个。clone 成功后按照下面步骤操作:- 在IDE中配置项目所使用的虚拟环境为我们新克隆的
- 进如到虚拟环境后,在mmdetection的项目目录下执行
python setup.py develop
,此时确定 mmdet被换成 2.23.0版本。
2 代码
2.1 configs/swin
在configs/swin 目录下新建文件:faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py
文件内容如下:
注意:训练的epoch在这个文件中改,我直接设置成了50,大家根据需要修改。
_base_ = [
'../_base_/models/faster_rcnn_swin_fpn.py',
'../_base_/datasets/faster_rcnn_coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
2.2 configs/base/models
在 configs/base/models 下新建文件:faster_rcnn_swin_fpn.py
文件内容如下:
注意: num_classes
需要根据你数据集的类别进行更改
# model settings
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'
model = dict(
type='FasterRCNN',
backbone=dict(
type='SwinTransformer',
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=