目录
目录
前言
- 🍨本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rbOOmire8OocQ90QM78DRA) 中的学习记录博客
- 🍖 原作者:K同学啊
说在前面:
- 本周学习目标:基本目标——本地读取并加载数据、测试集上的Accuracy达到93;拔高目标——通过调整使得Accuracy达到95%;调用模型识别一张本地的图片
- 学习重点:通过本地数据加载,调整模型参数以提高在测试集上的Accuracy
- 我的环境:Python3.8、Pycharm2020、torch1.12.1+cu113(ps:由于电脑无英伟达显卡,所以这里实际上还是用的cpu在运行)
一、前期准备
1.1 设置GPU
由于电脑硬件原因,这里仅支持cpu运行(首先导入需要的包,然后再查看系统支持的运行设备GPU or CPU)
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, random
import matplotlib.pyplot as plt
from PIL import Image
#Python图像库PIL(Python Image Library)是python的第三方图像处理库
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
输出显示为:cpu
1.2 导入数据
1.首先下载含天气的图片数据集,在代码文件相同的目录下建立一个名为data的目录,将下载的天气图片文件复制到data目录下
- 数据集介绍:该数据集是包括了四类天气的图片,分为为多云、雨、晴、日落四类天气,分别存储在四个子文件夹下,每类天气包含了300张图片,所以一共就是1200张图片
- 导入数据集的步骤如下:
1)使用函数将字符串类型的文件夹路径转换为pathlib.Path对象
2)使用glob方法获取data_dir路径下的所有文件路径,并以列表的形式存储在data_paths中
3)利用split()函数对data_paths中的每个文件路径执行分割操作,获取各个文件所属的类别名称并储存在classNames中
示例代码如下:
#1.2 导入数据
data_dir = './data/'
data_dir = pathlib.Path(data_dir) #使用函数将字符串类型的文件夹路径转换为pathlib.Path对象
data_paths = list(data_dir.glob('*')) #使用glob方法获取data_dir路径下的所有文件路径,并以列表的形式存储在data_paths中
classNames = [str(path).split('\\')[1] for path in data_paths]
#利用split()函数对data_paths中的每个文件路径执行分割操作,获取各个文件所属的类别名称并储存在classNames中
# 4类天气,各300张图片
print(classNames)
打印结果为:['cloudy', 'rain', 'shine', 'sunrise']
2.选取24张图片进行图片打印展示
1.image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]:使用列表推导式加载和显示图像用于在Matplotlib中的多个子图中显示从文件夹中加载的图像 1)在Matplotlib中,plt.subplots()函数用于创建一个包含多个子图的图形对象fig和子图对象的数组axes,当指定行数和列数创建子图时,返回的axes是一个二维数组,其中包含了所有的子图对象 2)axes.flat是一个迭代器,它可以让您按照一维的顺序访问axes数组中的所有子图对象 2.for循环:通过zip()函数将axes.flat和image_files两个可迭代对象进行配对,每次循环从中获取一个子图对象ax和一个图像文件名img_file 1)os.path.join()方法将文件夹路径image_folder 和图像文件名img_file拼接成完整的图像文件路径img_path 2)使用Image.open()方法从指定路径打开图像文件,将其加载为一个图像对象img 3)在当前子图ax中显示加载的图像img 4)关闭子图ax的坐标轴显示,以展现一个干净的画布
示例代码如下:
image_folder = './data/shine/' #指定图像文件夹路径
image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
for ax, img_file in zip(axes.flat, image_files):
img_path = os.path.join(image_folder, img_file)
img = Image.open(img_path)
ax.imshow(img)
ax.axis('off')
plt.tight_layout()
plt.show