前言
NTS-Net(Neural Task Scheduler Network)是一种针对多任务学习设计的深度学习模型,它特别适用于细粒度分类任务,这类任务涉及到将目标划分为大量细微差别的类别。例如,在动物分类中区分不同品种的鸟类或犬类。细粒度分类由于类别间的相似性较高,因此对特征提取和区分能力要求很高。
一、环境配置
创建专属环境
conda create -n ntsnet python=3.9
激活环境
conda activate ntsnet
安装 Pytorch GPU 环境
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "torch-1.13.0+cu116-cp39-cp39-win_amd64.whl"
pip install "torchvision-0.14.1+cu117-cp37-cp37m-win_amd64.whl" -i https://pypi.tuna.tsinghua.edu.cn/simple
二、代码测试
数据集
使用的数据集是 Stanford Dogs Dataset
超过20000张120种狗的图片,斯坦福狗数据集包含了来自世界各地的120种狗的图片。此数据集是使用ImageNet中的图像和注释构建的,用于细粒度图像分类任务。它最初是为精细纹理图像分类而收集的,这是一个具有挑战性的问题,因为某些狗的特征几乎相同,或者颜色和年龄不同。
●类别数量:120
●图像数量:20580
●标注:类标签、边界框
划分数据集
def split_data(split_path):
f_train = open(split_path + "/train.txt", "w", encoding="utf-8")
f_val = open(split_path + "/val.txt", "w", encoding="utf-8")
f_label = open(split_path + "/label.txt", "w", encoding="utf-8")
for name in tqdm(os.listdir(os.path.join(split_path, 'images')), desc="process data",
total=len(os.listdir(os.path.join(split_path, 'images')))):
file = [os.path.joi