最近需要训练一个有200多类的图片分类网络,搜了一遍,发现居然没有很合适用的开源项目,于是自己简单撸了一个轮子,项目地址: https://github.com/xuduo35/imgcls_pytorch。支持如下backbone:
- alexnet
- resnet18,resnet34,resnet50,resnet101, resnet152, resnext101_32x4d, resnext101_64x4d
- vgg11_bn, vgg16_bn
- densenet121, densenet169, densenet161
- inceptionv3, inceptionv4, inceptionresnetv2, bninception
- xception, xception_att
- dpn98, dpn107, dpn131
- senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d
- pnasnet5large
- polynet
- efficientnet
使用简便,第一步是按如下格式准备数据集,
- your_dataset_directory
- class1
- 1.jpg
- 2.jpg
- class2
- 1.jpg
- 2.jpg
- ...
- ...
- class1
自定义一个Dataset,实现如下