摘要:
记录MindSpore AI框架使用ResNet50迁移学习方法对ImageNet狼狗图片分类的过程、步骤。包括环境准备、下载数据集、数据集加载、构建模型、固定特征训练、训练评估和模型预测等。
一、概念
迁移学习的方法
在大数据集上训练得到预训练模型
初始化网络权重参数
固定特征提取器
应用于特定任务
ImageNet数据集中的狼和狗图像进行分类。
二、环境准备
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
输出:
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by:
三、数据准备
1.下载数据集
ImageNet
狼与狗每个分类
120张训练图像
30张验证图像
download接口
下载数据集
自动解压到当前目录下
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)
输出:
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip (11.3 MB)
file_sizes: 100%|███████████████████████████| 11.9M/11.9M [00:00<00:00, 140MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./datasets-Canidae
'./datasets-Canidae'
数据集目录结构:
datasets-Canidae/data/
└── Canidae
├── train
│ ├── dogs
│ └── wolves
└── val
├── dogs
└── wolves
2.加载数据集
mindspore.dataset.ImageFolderDataset接口
加载数据集
图像增强操作。
定义执行过程的输入参数:
batch_size = 18 # 批量大小
image_size = 224 # 训练图像空间大小
num_epochs = 5 # 训练周期数
lr = 0.001 # 学习率
momentum = 0.9 # 动量
workers = 4 # 并行线程个数
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
# 数据集目录路径
data_path_train = "./datasets-Canidae/data/Canidae/train/"
data_path_val = "./datasets-Canidae/data/Canidae/val/"
# 创建训练数据集
def create_dataset_canidae(dataset_path, usage):
"""数据加载"""
data_set = ds.ImageFolde