在GPU上Lenet5模型Mindspore框架实现

1.数据集

数据使用的是kaggle猫狗分类数据集,包含25000张猫狗图像。

2.代码

以下代码是Lenet模型在Mindspore框架上的简单实现,简单地把mindspore跑通了,还存在优化空间,可能与本实验是在GPU上跑Mindspore模型有关,同样的模型在pytorch上结果很好,但在Mindpore上性能一般。后续准备在华为服务器上用同样代码再进行一下实验。

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, UnidentifiedImageError
from sklearn.model_selection import train_test_split
from mindspore import nn, context, dataset as ds, Model
from mindspore.common.initializer import Normal
from mindspore.dataset.transforms import transforms
from mindspore.dataset.vision import transforms as vision
from mindspore.nn import Accuracy, Momentum
from mindspore.train.callback import Callback
import mindspore
from sklearn.metrics import confusion_matrix
import seaborn as sns

# 设置运行环境为GPU
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# 定义LeNet模型
class LeNet5(nn.Cell):
    def __init__(self, num_class=2, num_channel=3):
        super(LeNet5, self).__init__()
        # 定义网络层
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid', weight_init=Normal(0.02))
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid', weight_init=Normal(0.02))
        self.fc1 = nn.Dense(16*5*5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

    def construct(self, x):
        # 前向传播过程
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

# 加载Cats vs Dogs数据集
class CatsVsDogsDataset:
    def __init__(self, data_dir, split='train', test_size=0.2, random_state=42):
        self.data_dir = data_dir
        self.split = split

        # 获取所有图像文件和对应的标签
        self.image_files = []
        self.labels = []

        cat_dir = os.path.join(data_dir, 'Cat')
        dog_dir = os.path.join(data_dir, 'Dog')

        for file_name in os.listdir(cat_dir):
            file_path = os.path.join(cat_dir, file_name)
            # 仅处理图像文件
            if os.path.isfile(file_path) and file_path.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.image_files.append(file_path)
                self.labels.append(0)  # Cat标签为0

        for file_name in os.listdir(dog_dir):
            file_path = os.path.join(dog_dir, file_name)
            # 仅处理图像文件
            if os.path.isfile(file_path) and file_path.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.image_files.append(file_path)
                self.labels.append(1)  # Dog标签为1

        # 划分训练集和测试集
        train_files, test_files, train_labels, test_labels = train_test_split(
            self.image_files, self.labels, test_size=test_size, random_state=random_state, stratify=self.labels
        )

        if self.split == 'train':
            self.image_files, self.labels = train_files, train_labels
        else:
            self.image_files, self.labels = test_files, test_labels

    def __getitem__(self, index):
        img_path = self.image_files[index]
        label = self.labels[index]

        # 打开图像并处理异常
        try:
            image = Image.open(img_path).convert('RGB')
            image = np.asarray(image)

        except (IOError, UnidentifiedImageError) as e:
            print(f"Error loading image {
     
     img_path}: {
     
     e}")
            # 返回一个空图像和无效标签以避免崩溃
            image = None
            label = None

        return image, label

    def __len__(self):
        return len(self.image_files)

def create_dataset(data_dir, batch_size=32, repeat_size=1, shuffle=True, split='train'):
    dataset = CatsVsDogsDataset(data_dir, split)

    # 过滤掉无法加载的图像
    valid_data = [(image, label) for image, label in dataset if image is not 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值