1.数据集
数据使用的是 Kaggle 猫狗分类数据集,包含 25000 张猫狗图像。
2.代码
以下代码是 LeNet 模型在 MindSpore 框架上的简单实现,只是简单地把 MindSpore 跑通了,还存在优化空间:同样的模型在 PyTorch 上结果很好,但在 MindSpore 上性能一般,这可能与本实验是在 GPU 上运行 MindSpore 模型有关。后续准备在华为服务器上用同样的代码再进行一次实验。
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, UnidentifiedImageError
from sklearn.model_selection import train_test_split
from mindspore import nn, context, dataset as ds, Model
from mindspore.common.initializer import Normal
from mindspore.dataset.transforms import transforms
from mindspore.dataset.vision import transforms as vision
from mindspore.nn import Accuracy, Momentum
from mindspore.train.callback import Callback
import mindspore
from sklearn.metrics import confusion_matrix
import seaborn as sns
# Configure the MindSpore runtime: target the GPU device and use graph mode.
context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
# LeNet-5 style convolutional classifier
class LeNet5(nn.Cell):
    """Two conv/pool stages followed by three dense layers.

    Args:
        num_class: Number of output logits produced by the last dense layer.
        num_channel: Number of channels of the input image tensor.

    Note:
        ``fc1`` takes ``16 * 5 * 5`` flattened features. With the two
        unpadded 5x5 convolutions and the two 2x2/stride-2 poolings below,
        that size corresponds to a 32x32 spatial input
        (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self, num_class=2, num_channel=3):
        super().__init__()
        # Convolutional feature extractor ('valid' padding shrinks each
        # spatial side by kernel_size - 1).
        self.conv1 = nn.Conv2d(
            in_channels=num_channel,
            out_channels=6,
            kernel_size=5,
            pad_mode='valid',
            weight_init=Normal(sigma=0.02),
        )
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            pad_mode='valid',
            weight_init=Normal(sigma=0.02),
        )
        # Dense classifier head.
        self.fc1 = nn.Dense(
            in_channels=16 * 5 * 5,
            out_channels=120,
            weight_init=Normal(sigma=0.02),
        )
        self.fc2 = nn.Dense(
            in_channels=120,
            out_channels=84,
            weight_init=Normal(sigma=0.02),
        )
        self.fc3 = nn.Dense(
            in_channels=84,
            out_channels=num_class,
            weight_init=Normal(sigma=0.02),
        )
        # Parameter-free layers shared by every stage.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

    def construct(self, x):
        """Run the forward pass and return the raw (un-normalised) logits."""
        # Stage 1 and stage 2: convolution -> ReLU -> max pooling.
        stage1 = self.max_pool2d(self.relu(self.conv1(x)))
        stage2 = self.max_pool2d(self.relu(self.conv2(stage1)))
        # Collapse everything but the leading (batch) axis before the head.
        flat = stage2.view(stage2.shape[0], -1)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Cats vs Dogs dataset loader
class CatsVsDogsDataset:
    """Random-access view over one split of a Cat/Dog image folder.

    ``data_dir`` must contain a ``Cat`` and a ``Dog`` sub-directory. Every
    ``.png``/``.jpg``/``.jpeg`` file found directly inside them is collected
    (Cat -> label 0, Dog -> label 1) and the combined list is divided into a
    stratified train/test split with ``train_test_split``.

    Args:
        data_dir: Root directory holding the ``Cat`` and ``Dog`` folders.
        split: ``'train'`` selects the training part; any other value
            selects the test part.
        test_size: Forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split``.
    """

    # Accepted file extensions (compared against the lower-cased file name).
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, split='train', test_size=0.2, random_state=42):
        self.data_dir = data_dir
        self.split = split
        # Gather all image files together with their class labels.
        self.image_files = []
        self.labels = []
        for label, class_name in enumerate(('Cat', 'Dog')):
            class_files = self._list_images(os.path.join(data_dir, class_name))
            self.image_files.extend(class_files)
            self.labels.extend([label] * len(class_files))
        # Stratified split so both parts keep the Cat/Dog ratio.
        train_files, test_files, train_labels, test_labels = train_test_split(
            self.image_files,
            self.labels,
            test_size=test_size,
            random_state=random_state,
            stratify=self.labels,
        )
        if self.split == 'train':
            self.image_files, self.labels = train_files, train_labels
        else:
            self.image_files, self.labels = test_files, test_labels

    @staticmethod
    def _list_images(directory):
        """Return the sorted paths of the image files directly in *directory*.

        Sub-directories and files with other extensions are ignored.
        """
        paths = []
        # os.listdir yields entries in arbitrary order; sorting keeps the
        # random_state-seeded split identical from one construction to the
        # next, so separately built 'train' and 'test' instances stay disjoint.
        for file_name in sorted(os.listdir(directory)):
            file_path = os.path.join(directory, file_name)
            if (os.path.isfile(file_path)
                    and file_name.lower().endswith(CatsVsDogsDataset._IMAGE_EXTENSIONS)):
                paths.append(file_path)
        return paths

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at *index*.

        ``image`` is the RGB-converted picture as a NumPy array. When the
        file cannot be opened or decoded, the error is printed and
        ``(None, None)`` is returned so callers can filter the sample out
        instead of crashing.

        Raises:
            IndexError: If *index* is out of range (this is also what ends
                plain ``for image, label in dataset`` iteration).
        """
        img_path = self.image_files[index]
        label = self.labels[index]
        try:
            # The context manager closes the underlying file deterministically,
            # including when decoding fails half-way.
            with Image.open(img_path) as img:
                image = np.asarray(img.convert('RGB'))
        except (IOError, UnidentifiedImageError) as e:
            print(f"Error loading image {img_path}: {e}")
            # Signal an unusable sample with a None image and None label.
            image = None
            label = None
        return image, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_files)
def create_dataset(data_dir, batch_size=32, repeat_size=1, shuffle=True, split='train'):
dataset = CatsVsDogsDataset(data_dir, split)
# 过滤掉无法加载的图像
valid_data = [(image, label) for image, label in dataset if image is not

最低0.47元/天 解锁文章
1724

被折叠的 条评论
为什么被折叠?



