kaggle竞赛系列基于图像对水稻分类代码案例

目录

依赖环境

代码

导入依赖包

定义数据集路径:

创建训练集、验证集和测试集的文件夹:

代码的作用:

设置新的数据集路径与类别名称

代码的作用:

定义数据预处理和增强变换:

代码的作用:

定义数据集评估划分与batch大小

代码的作用:

可视化

代码的作用:

 评估可视化

代码的作用:

网络结构定义

代码的作用:

定义损失函数和优化器,并训练模型

 模型可视化评估

代码的作用:

下载地址:

python深度学习pytorch水稻图像分类完整案例


依赖环境

!pip install split-folders
!pip install torch-summary
!pip install torch matplotplib

代码

导入依赖包

import os
import pathlib
import numpy as np
import splitfolders
import itertools
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from termcolor import colored 
from datetime import datetime
import warnings
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix, classification_report
from torchsummary import summary
from numpy import asarray
from PIL import Image
  • ospathlib 用于文件和目录操作。
  • numpy 用于数组和数值计算。
  • splitfolders 用于分割数据集。
  • itertools 提供迭代生成器。
  • matplotlib.pyplot 用于绘图和数据可视化。
  • torchtorch.nntorch.nn.functionaltorch.optimtorchvisiontorchvision.transforms 用于构建和训练神经网络模型。
  • termcolor 用于终端中的彩色输出。
  • datetime 用于时间操作。
  • warnings 用于忽略警告信息。
  • tqdm.notebook 用于显示进度条。
  • sklearn.metrics 提供评估指标,包括混淆矩阵和分类报告。
  • torchsummary 用于总结模型结构。
  • numpy.asarrayPIL.Image 用于图像处理。

定义数据集路径

data = './Rice_Image_Dataset'
data = pathlib.Path(data)

创建训练集、验证集和测试集的文件夹

splitfolders.ratio(input=data, output='rice_imgs', seed=42, ratio=(0.7, 0.15, 0.15))
  • 使用 splitfolders.ratio 函数将数据集按照 7:1.5:1.5 的比例划分为训练集、验证集和测试集。
  • input 参数指定输入数据集的路径。
  • output 参数指定输出文件夹的名称 'rice_imgs',划分后的数据集将保存在这个文件夹中。
  • seed 参数设置随机种子,以确保划分结果的可重复性。
  • ratio 参数指定训练集、验证集和测试集的比例,分别为 70%、15% 和 15%。

代码的作用:

这段代码通过 splitfolders 库将原始的水稻图像数据集划分为训练集、验证集和测试集,以便在模型训练、验证和测试过程中使用不同的数据子集,从而提高模型的泛化能力和评估准确性。

设置新的数据集路径与类别名称

root_dir = './rice_imgs'
root_dir = pathlib.Path(root_dir)
Arborio='./Rice_Image_Dataset/Arborio'

Arborio_classes=os.listdir(Arborio)
Rice_classes = os.listdir(root_dir)
batchsize=8

from colorama import Fore, Style

print(Fore.GREEN +str(Rice_classes))
print(Fore.YELLOW +"\nTotal number of classes are: ", len(Rice_classes))
  • root_dir 定义为之前划分后的数据集路径 './rice_imgs',并转换为路径对象。
  • Arborio 定义为原始数据集中某个类别的路径。
  • 使用 os.listdir(Arborio) 获取 Arborio 类别中的所有文件和文件夹名称。
  • 使用 os.listdir(root_dir) 获取新的数据集路径中的所有类别名称。
  • 导入 colorama 库中的 ForeStyle,用于终端输出的颜色设置。

代码的作用:

这段代码的主要作用是设置新的数据集路径,并获取数据集中各个类别的名称,以便在后续的数据加载和处理过程中使用。此外,代码还打印出了类别名称和总数,以便进行检查和验证。

定义数据预处理和增强变换

transform = transforms.Compose(
    [
        transforms.Resize((250,250)),
        transforms.ToTensor(),
        transforms.Normalize((0),(1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
    ]
)
  • transforms.Compose:将多个变换操作组合在一起。
  • transforms.Resize((250,250)):将图像调整为 250x250 的固定尺寸。
  • transforms.ToTensor():将图像转换为 PyTorch 的张量格式,并将像素值归一化到 [0,1] 的范围。
  • transforms.Normalize((0),(1)):标准化图像,使图像的每个通道均值为 0,标准差为 1。
  • transforms.RandomHorizontalFlip():随机水平翻转图像,用于数据增强,以增加模型的泛化能力。
  • transforms.RandomRotation(30):随机旋转图像最多 30 度,用于数据增强。

代码的作用:

这段代码通过定义一个数据预处理和增强的变换流水线,在加载图像数据时自动对图像进行调整大小、转换为张量、标准化、随机水平翻转和随机旋转等操作。这些预处理和增强操作有助于提高模型的训练效果和泛化能力。

定义数据集评估划分与batch大小

import torch.utils.data
batch_size = 32

# Read train images as a dataset
train_set = torchvision.datasets.ImageFolder(
    os.path.join(root_dir, 'train'), transform=transform
)
# Create a Data Loader
train_loader = torch.utils.data.DataLoader(
    train_set,  batch_size = batch_size, shuffle=True
)
print(colored(f'Train Folder :\n ', 'green', attrs=['bold']))
print(train_set)
print('_'*100)

#############################################################
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

E寻数据

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值