目录
依赖环境
!pip install split-folders
!pip install torch-summary
!pip install torch matplotplib
代码
导入依赖包
import os
import pathlib
import numpy as np
import splitfolders
import itertools
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from termcolor import colored
from datetime import datetime
import warnings
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix, classification_report
from torchsummary import summary
from numpy import asarray
from PIL import Image
os
和pathlib
用于文件和目录操作。numpy
用于数组和数值计算。splitfolders
用于分割数据集。itertools
提供迭代生成器。matplotlib.pyplot
用于绘图和数据可视化。torch
、torch.nn
、torch.nn.functional
、torch.optim
、torchvision
和torchvision.transforms
用于构建和训练神经网络模型。termcolor
用于终端中的彩色输出。datetime
用于时间操作。warnings
用于忽略警告信息。tqdm.notebook
用于显示进度条。sklearn.metrics
提供评估指标,包括混淆矩阵和分类报告。torchsummary
用于总结模型结构。numpy.asarray
和PIL.Image
用于图像处理。
定义数据集路径:
data = './Rice_Image_Dataset'
data = pathlib.Path(data)
创建训练集、验证集和测试集的文件夹:
splitfolders.ratio(input=data, output='rice_imgs', seed=42, ratio=(0.7, 0.15, 0.15))
- 使用
splitfolders.ratio
函数将数据集按照 7:1.5:1.5 的比例划分为训练集、验证集和测试集。 input
参数指定输入数据集的路径。output
参数指定输出文件夹的名称'rice_imgs'
,划分后的数据集将保存在这个文件夹中。seed
参数设置随机种子,以确保划分结果的可重复性。ratio
参数指定训练集、验证集和测试集的比例,分别为 70%、15% 和 15%。
代码的作用:
这段代码通过 splitfolders
库将原始的水稻图像数据集划分为训练集、验证集和测试集,以便在模型训练、验证和测试过程中使用不同的数据子集,从而提高模型的泛化能力和评估准确性。
设置新的数据集路径与类别名称
root_dir = './rice_imgs'
root_dir = pathlib.Path(root_dir)
Arborio='./Rice_Image_Dataset/Arborio'
Arborio_classes=os.listdir(Arborio)
Rice_classes = os.listdir(root_dir)
batchsize=8
from colorama import Fore, Style
print(Fore.GREEN +str(Rice_classes))
print(Fore.YELLOW +"\nTotal number of classes are: ", len(Rice_classes))
root_dir
定义为之前划分后的数据集路径'./rice_imgs'
,并转换为路径对象。Arborio
定义为原始数据集中某个类别的路径。- 使用
os.listdir(Arborio)
获取Arborio
类别中的所有文件和文件夹名称。 - 使用
os.listdir(root_dir)
获取新的数据集路径中的所有类别名称。 - 导入
colorama
库中的Fore
和Style
,用于终端输出的颜色设置。
代码的作用:
这段代码的主要作用是设置新的数据集路径,并获取数据集中各个类别的名称,以便在后续的数据加载和处理过程中使用。此外,代码还打印出了类别名称和总数,以便进行检查和验证。
定义数据预处理和增强变换:
transform = transforms.Compose(
[
transforms.Resize((250,250)),
transforms.ToTensor(),
transforms.Normalize((0),(1)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(30),
]
)
transforms.Compose
:将多个变换操作组合在一起。transforms.Resize((250,250))
:将图像调整为 250x250 的固定尺寸。transforms.ToTensor()
:将图像转换为 PyTorch 的张量格式,并将像素值归一化到 [0,1] 的范围。transforms.Normalize((0),(1))
:标准化图像,使图像的每个通道均值为 0,标准差为 1。transforms.RandomHorizontalFlip()
:随机水平翻转图像,用于数据增强,以增加模型的泛化能力。transforms.RandomRotation(30)
:随机旋转图像最多 30 度,用于数据增强。
代码的作用:
这段代码通过定义一个数据预处理和增强的变换流水线,在加载图像数据时自动对图像进行调整大小、转换为张量、标准化、随机水平翻转和随机旋转等操作。这些预处理和增强操作有助于提高模型的训练效果和泛化能力。
定义数据集评估划分与batch大小
import torch.utils.data
batch_size = 32
# Read train images as a dataset
train_set = torchvision.datasets.ImageFolder(
os.path.join(root_dir, 'train'), transform=transform
)
# Create a Data Loader
train_loader = torch.utils.data.DataLoader(
train_set, batch_size = batch_size, shuffle=True
)
print(colored(f'Train Folder :\n ', 'green', attrs=['bold']))
print(train_set)
print('_'*100)
#############################################################