mac深度学习pytorch:ResNet50网络图像分类篇——花数据集图像分类(5类)保姆级详细过程及可视化

原代码及数据集下载:【度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码 -  优快云 App】/windows度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码_resnet50图像分类-优快云博客

1.数据集

下载数据集至本地(点击原文链接进行下载)

2.代码

2.1数据预处理

**windows:**
import os
from shutil import copy
import random

def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)

# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = 'yourdataroad/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1

# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # 某一类别动作的子目录
    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
    for index, image in enumerate(images):
        # eval_index 中保存验证集val的图像名称
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # 将选中的图像复制到新路径

        # 其余的图像保存在训练集train中
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")

mac设备报错:

 File "/*/pytorch/resnet/data预处理.py", line 32, in <module>
    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
NotADirectoryError: [Errno 20] Not a directory: '/yourdataroad/flower_data/.DS_Store/'

错误信息表明在尝试列出目录内容时,cla_path 指向了一个不是目录的文件,具体来说是 /yourdataroad/flower_data/.DS_Store/.DS_Store 文件是 macOS 系统用来存储文件夹的自定义属性的隐藏文件,例如文件的图标位置和背景色等,它不是一个目录。

要修正这个错误,你需要在获取目录列表时排除 .DS_Store 文件

**macos**
import os
from shutil import copy
import random

def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)

# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = '/yourdataroad/flower_data'
flower_class = [cla for cla in os.listdir(file_path) if os.path.isdir(os.path.join(file_path, cla)) and ".txt" not in cla]

# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1

# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
    cla_path = os.path.join(file_path, cla)  # 某一类别动作的子目录
    images = [img for img in os.listdir(cla_path) if os.path.isfile(os.path.join(cla_path, img))]  # 确保是文件
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
    for index, image in enumerate(images):
        # eval_index 中保存验证集val的图像名称
        if image in eval_index:
            image_path = os.path.join(cla_path, image)
            new_path = os.path.join('flower_data/val', cla)
            copy(image_path, new_path)  # 将选中的图像复制到新路径

        # 其余的图像保存在训练集train中
        else:
            image_path = os.path.join(cla_path, image)
            new_path = os.path.join('flower_data/train', cla)
            copy(image_path, new_path)
        print("\\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")

代码运行结果:

2.2训练模型

假设你已知restnet网络原理,不知道也没关系。以下是在macos上运行的完整代码,windows或其他版本可以去看本文开头链接。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os

# 定义数据转换
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
data_dir = '/Users/xiaqizai/Downloads/flower_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# 加载预训练的ResNet-50模型
model = models.resnet50(weights=model
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值