
图像分类通用简易框架 - ClassifySI
github:https://github.com/Giperx/ClassifySI
新能源汽车图像分类-ClassifySI框架应用实践:https://blog.youkuaiyun.com/m0_51657509/article/details/145300821
实现一个简易框架,用于进行深度学习图像分类的入门探索。
可通过提前准备或爬取的数据集和自定义的网络结构进行训练及测试。
准备
- 数据集
- 环境
- 训练网络搭建
数据集
- 准备数据集
-
通过开源数据集网站或数据集官网进行获取,如paperswithcode、paddlepaddle等。
-
或自行准备数据。项目提供crawler.py简易爬虫通过搜索引擎获取图片。
-
python crawler.py -q [keyword] -o [Path] -n [numberImages]
-
options:
- -q --query,搜索关键词
- -o --output,图片保存目录,
./raw_data/
+[outfoldname]
- -n --number,需要获取的图片数量
-
爬取后会有部分图片不符合要求,可能需要进一步手动剔除。
-
- 预处理数据
-
初始数据文件结构:
-
raw_data/ ├── [class1Name] │ ├── [1].jpg │ ├── ... │ ├── [n].jpg ├──... ├── [classnName] │ ├── ...
-
-
项目提供dataAugmSplit.py对数据进行增强扩充和划分为训练集、验证集和测试集。
-
python dataAugmSplit.py -raw [Path] -out [Path] -tr [float] -val [float] -te [float] -tr_times [number] -val_times [number] -te_times [number] -p [float]
-
options:
- -raw --raw_dir,初始数据所在位置,default=‘./raw_data’
- -out --output_dir,输出处理后数据位置,default=‘./data’
- -tr --train_ratio,训练集所占初始数据比例,default=0.8
- -val --val_ratio,验证集所占初始数据比例,default=0.1
- -te --test_ratio,测试集所占初始数据比例,default=0.1
- 划分使用
train_test_split()
,随机种子固定为random_state=42
- 划分使用
- -tr_times --augment_times_train,训练集增强次数,default=3
- -val_times --augment_times_val,验证集增强次数,default=1
- -test_times --augment_times_test,测试集增强次数,default=0
- -p --prob_aa,使用 AutoAugment 策略增强的概率,default=0.45
-
-
如果初始数据集过大不想增强扩充,可设置三个增强次数参数为 0
-
划分后文件结构:
-
data/ ├── test │ ├── [class1Name] │ ├── [1].jpg │ ├── ... │ ├── [n].jpg │ ├──... │ ├── [classnName] │ ├── ... ├── train │ ├── ... ├── val │ ├── ...
-
example:
python crawler.py -q 孙悟空 -o wukong -n 10
环境
- 安装 python。
conda create -n classifysi python=3.8 -y
conda activate classifysi
- 安装 pytorch、torchvision,下为示例。
# 换源(可选):
#conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
#conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
#conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
# for linux
#conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
# for legacy win-64
#conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/peterjc123/
# 安装pytorch:
# OnlyCPU
# conda install pytorch torchvision torchaudio cpuonly -c pytorch -y
# pip3 install torch torchvision torchaudio
# GPU: 在https://pytorch.org/get-started/previous-versions/中查看符合cuda版本进行安装
#example1: conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3
# pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
#example2: pip3 install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
进入项目目录,安装其它依赖。
pip3 install -r requirements.txt
训练网络搭建
在modelForClassify.py中定义模型结构,或使用注释中提供的两个预训练网络模型:ResNet-18、MobileNetV2
class Classifier(nn.Module):
def __init__(self, num_classes):
super(Classifier, self).__init__()
# 定义模型
def forward(self, x):
# 定义前向传播
注: 在训练阶段,resize
为 $ 448 \times 448 $ 。
训练及测试
训练
下面三行命令等价(默认参数)
-
python train.py
-
python train.py -e 50 -tr ./data/train -val ./data/val -g 0 -batch 32 1 -sd -1
-
python train.py --epochs 50 --train_dir ./data/train --val_dir ./data/val --gpus 0 --batch_size 32 1 --seed -1
-
options:
-
-e --epochs,训练轮数,default=50
-
-tr --train_dir,训练集路径,default=‘./data/train’
-
-val --val_dir,验证集路径,default=‘./data/val’
-
-g --gpus,使用 GPU 数量,default=0,0 表示使用 CPU 进行训练
-
-batch --batch_size,训练和验证时的批次大小,default=[32, 1]
-
-sd --seed,随机数种子,default=-1,-1 表示不进行随机数种子的固定
-
-resume --resume,模型权重文件路径,default=‘’。可从路径中加载权重继续训练
-
example:
-
python train.py -resume ./save_model/model.pth
-
-
训练结束时,log/
中保存输出日志信息;/pics
中保存训练的 loss、acc 曲线;./save_model/model.pth
中保存训练过程中达到 best_acc 的模型权重参数。
如果报错(Linux):
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
可在命令行运行unset LD_LIBRARY_PATH
解决。参考issues
example:
测试
下面三行命令等价(默认参数)
-
python test.py
-
python test.py -t ./data/test -batch 1 -m ./save_model/model.pth
-
python test.py -test_dir ./data/test --batch_size 1 --model_path ./save_model/model.pth
-
options:
- -t --test_dir,测试集路径,default=‘./data/test’
- -batch --batch_size,测试时的批次大小,default=1
- -m --model_path,模型权重文件路径,default=‘./save_model/model.pth’
测试结束时,log/
中保存输出日志信息(模型准确率 Accuracy、F1 Score,各类精确率 Precision、召回率 Recall、特异度 Specificity、F1 Score,模型参数量,推理速度);/pics
中保存测试生成的混淆矩阵图;./save_model/model_all.model
保存整个模型(可导入Netron进行可视化)。