代码仓库:https://github.com/open-mmlab/mmclassification
文档教程:https://mmclassification.readthedocs.io/en/latest/
一、安装
1.创建环境
加载 anaconda ,创建一个 python 3.8 的环境。
# 创建 python=3.8 的环境
conda create --name mmclassification python=3.8
# 激活环境
conda activate mmclassification
安装torch
查看已安装的cuda版本
nvcc --version
# Cuda compilation tools, release 11.3, V11.3.58
根据已有的cuda版本安装torch
pytorch下载地址:
https://pytorch.org/get-started/previous-versions/
# CUDA 11.3
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
安装 mmcv-full 模块,mmcv-full 模块安装时候需要注意 torch 和 cuda 版本。
在安装 mmcv-full 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档。可使用以下命令验证:
python -c 'import torch;print(torch.__version__)'
使用 mim 安装
# 安装 mim
pip install -U openmim
# 安装基础库 mmcv 完整版
mim install mmcv-full
安装mmdetection
# 源码安装 mmdet
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
mim install -v -e .
安装mmclassification
# 源码安装 mmcls
cd ..
git clone https://github.com/open-mmlab/mmclassification.git
cd mmclassification
mim install -v -e .
2.验证安装
为了验证 MMClassification 的安装是否正确,我们提供了一些示例代码来执行模型推理。
第 1 步 我们需要下载配置文件和模型权重文件
mim download mmcls --config resnet50_8xb32_in1k --dest .
第 2 步 验证示例的推理流程
如果你是从源码安装的 mmcls,那么直接运行以下命令进行验证:
python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_in1k_20210831-ea4938fc.pth --device cpu
你可以看到命令行中输出了结果字典,包括 pred_label,pred_score 和 pred_class 三个字段。
3.数据集
flower 数据集包含 5 种类别的花卉图像:雏菊 daisy 588张,蒲公英 dandelion 556张,玫瑰 rose 583张,向日葵 sunflower 536张,郁金香 tulip 585张。
数据集下载链接:国际网:https://www.dropbox.com/s/snom6v4zfky0flx/flower_dataset.zip?dl=0国内网:https://pan.baidu.com/s/1RJmAoxCD_aNPyTRX6w97xQ 提取码: 9x5u
3.1划分数据集
将数据集按照 8:2 的比例划分成训练和验证子数据集,并将数据集整理成 ImageNet的格式
将训练子集和验证子集放到 train 和 val 文件夹下。
创建并编辑标注文件将所有类别的名称写到 classes.txt 中,每行代表一个类别。
tulip
dandelion
daisy
sunflower
rose
生成训练(可选)和验证子集标注列表 train.txt 和 val.txt ,每行应包含一个文件名和其对应的标签。如下,可将处理好的数据集迁移到 mmclassification/data 文件夹下。
数据集划分代码 split_data.py 如下,执行:
python split_data.py [源数据集路径] [目标数据集路径]
import os
import sys
import shutil
import numpy as np
def load_data(data_path):
count = 0
data = {}
for dir_name in os.listdir(data_path):
dir_path = os.path.join(data_path, dir_name)
if not os.path.isdir(dir_path):
continue
data[dir_name] = []
for file_name in os.listdir(dir_path):
file_path = os.path.join(dir_path, file_name)
if not os.path.isfile(file_path):
continue
data[dir_name].append(file_path)
count += len(data[dir_name])
print("{} :{}".format(dir_name, len(data[dir_name])))
print("total of image : {}".format(count))
return data
def copy_dataset(src_img_list, data_index, target_path):
target_img_list = []
for index in data_index:
src_img = src_img_list[index]
img_name = os.path.split(src_img)[-1]
shutil.copy(src_img, target_path)
target_img_list.append(os.path.join(target_path, img_name))
return target_img_list
def write_file(data, file_name):
if isinstance(data, dict):
write_data = []
for lab, img_list in data.items():
for img in img_list:
write_data.append("{} {}".format(img, lab))
else:
write_data = data
with open(file_name, "w") as f:
for line in write_data:
f.write(line + "\n")
print("{} write over!".format(file_name))
def split_data(src_data_path, target_data_path, train_rate=0.8):
src_data_dict = load_data(src_data_path)
classes = []
train_dataset, val_dataset = {}, {}
train_count, val_count = 0, 0
for i, (cls_name, img_list) in enumerate(src_data_dict.items()):
img_data_size = len(img_list)
random_index = np.random.choice(img_data_size, img_data_size,replace=False)
train_data_size = int(img_data_size * train_rate)
train_data_index = random_index[:train_data_size]
val_data_index = random_index[train_data_size:]
train_data_path = os.path.join(target_data_path, "train", cls_name)
val_data_path = os.path.join(target_data_path, "val", cls_name)
os.makedirs(train_data_path, exist_ok=True)
os.makedirs(val_data_path, exist_ok=True)
classes.append(cls_name)
train_dataset[i] = copy_dataset(img_list, train_data_index,train_data_path)
val_dataset[i] = copy_dataset(img_list, val_data_index, val_data_path)
print("target {} train:{}, val:{}".format(cls_name,len(train_dataset[i]), len(val_dataset[i])))
train_count += len(train_dataset[i])
val_count += len(val_dataset[i])
print("train size:{}, val size:{}, total:{}".format(train_count, val_count, train_count + val_count))
write_file(classes, os.path.join(target_data_path, "classes.txt"))
write_file(train_dataset, os.path.join(target_data_path, "train.txt"))
write_file(val_dataset, os.path.join(target_data_path, "val.txt"))
def main():
src_data_path = sys.argv[1]
target_data_path = sys.argv[2]
split_data(src_data_path, target_data_path, train_rate=0.8)
if __name__ == '__main__':
main()
二、代码教学
pytorch官方demo:
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
OpenMMLab 项目中的重要概念——配置文件
作业
2376

被折叠的 条评论
为什么被折叠?



