一、数据集定义
使用mindspore.dataset中的ImageFolderDataset接口加载图像分类数据集,ImageFolderDataset接口传入数据集文件上层目录,每个子目录分别放入不同类别的图像。使用python定义一个create_dataset函数用于创建数据集,在函数中使用mindspore.dataset.vision接口中的Decode、Resize、Normalize、HWC2CHW对图像进行解码、调整尺寸、归一化和通道变换预处理,其中Resize根据模型需要的图像大小进行设置,归一化操作可以通过设置mean和std约束范围。如将mean设置为[127.5,127.5,127.5],std设置为[255,255,255],可以将数据归一化到[0.5~0.5]范围内。
数据集加载:
import mindspore
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import context, save_checkpoint, ops, Tensor
import mindspore.dataset as ds
import mindspore.dataset.vision as CV
import mindspore.dataset.transforms as C
from mindspore import dtype as mstype
def create_dataset(data_path, batch_size=24, repeat_num=1):
"""定义数据集"""
data_set = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=True)
image_size = [100, 100]
mean = [127.5, 127.5, 127.5]
std = [255., 255., 255.]
trans = [
CV.Decode(),
CV.Resize(image_size),
CV.Normalize(mean=mean, std=std),
CV.HWC2CHW()
]
# 实现数据的map映射、批量处理和数据重复的操作
type_cast_op = C.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_wor

本文介绍使用Python和MindSpore实现CNN图像分类模型的训练。首先用ImageFolderDataset加载图像分类数据集并进行预处理;接着定义由4层卷积和2层全连接组成的CNN网络结构;然后单独定义类计算损失;最后定义训练流程,设置迭代次数等参数进行训练,并在每轮结束后验证模型准确率,还可设置运行设备。
最低0.47元/天 解锁文章
2501





