1.数据加载
1.1 读取文本文件
- 方法一:使用 open() 函数和 read() 方法
# 打开文件并读取全部内容
file_path = 'example.txt' # 替换为你的文件路径
with open(file_path, 'r') as file:
content = file.read()
print(content)
- 方法二:逐行读取文件内容
# 逐行读取文件内容
file_path = 'example.txt' # 替换为你的文件路径
with open(file_path, 'r') as file:
for line in file:
print(line.strip()) # strip() 方法用于去除行末尾的换行符
- 方法三:指定编码读取文件内容(如果文本文件不是UTF-8编码)
# 指定编码读取文件内容
file_path = 'example.txt' # 替换为你的文件路径
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read()
print(content)
- 方法四:一次读取多行内容
# 一次读取多行内容
file_path = 'example.txt' # 替换为你的文件路径
with open(file_path, 'r') as file:
lines = file.readlines()
for line in lines:
print(line.strip()) # strip() 方法用于去除行末尾的换行符
1.2 读取图片
- 使用Pillow(PIL)库
安装了Pillow库:pip install Pillow
from PIL import Image
# 打开图片文件
img = Image.open('example.jpg') # 替换成你的图片文件路径
# 显示图片信息
print("图片格式:", img.format)
print("图片大小:", img.size)
print("图片模式:", img.mode)
# 显示图片
img.show()
# 转换为numpy数组(如果需要在其他库中处理图像)
import numpy as np
img_array = np.array(img)
- 使用OpenCV库
安装OpenCV库:pip install opencv-python
import cv2
# 读取图片
img = cv2.imread('example.jpg') # 替换成你的图片文件路径
# 显示图片信息
print("图片尺寸:", img.shape) # 高度、宽度、通道数(rgb默认是3)
print("像素值范围:", img.dtype) # 数据类型(像素值类型)
# 可选:显示图片(OpenCV中显示图像的方法)
cv2.imshow('image', img)
cv2.waitKey(0)
cv2.destroyAllWindows()
# 可选:将BGR格式转换为RGB格式
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 可选:保存图片
cv2.imwrite('output.jpg', img)
- 应用
# 继承Dataset类
class MyData(Dataset):
# 初始化
def __init__(self, root_dir, image_dir, label_dir):
# root_dir 训练目录地址
self.root_dir = root_dir
# 图片地址
self.image_dir = image_dir
# label_dir 标签地址
self.label_dir = label_dir
# 训练样本地址
self.path = os.path.join(self.root_dir, self.image_dir)
# 将所有样本地址(名称)变成一个列表
self.img_path = os.listdir(self.path)
# 读取图片,并且对应label
# idx 图片下标
def __getitem__(self, idx):
# 图片名称
img_name = self.img_path[idx]
# 图片路径
img_item_path = os.path.join(self.path, img_name)
# 获取图片
img = Image.open(img_item_path)
# 打开文件并读取全部内容(文件存取对应样本的标签)
file_path = os.path.join(self.root_dir, self.label_dir, img_name.split('.jpg')[0])
with open(file_path + ".txt", 'r') as file:
label = file.read()
# 获取标签
return img, label
# 训练样本长度
def __len__(self):
return len(self.img_path)
root_dir = "hymenoptera_data/train"
ants_image_dir = "ants_image"
ants_label_dir = "ants_label"
bees_image_dir = "bees_image"
bees_label_dir = "bees_label"
ants_dataset = MyData(root_dir, ants_image_dir, ants_label_dir)
img, label = ants_dataset[0] # 返回数据集第一个样本的图片和标签
img.show() # 展示图片
bees_dataset = MyData(root_dir, bees_image_dir, bees_label_dir)
import cv2
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
img_path = "dataset/train/ants_image/0013035.jpg"
# 打开图片
img = Image.open(img_path)
# 利用cv2打开图片,直接是ndarray格式
img_cv2 = cv2.imread(img_path)
print(img_cv2)
write = SummaryWriter('logs')
# 创建一个ToTensor的对象(PIL Image or ndarray to tensor)
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
print(tensor_img)
1.3 数据可视化(TensorBoard)
SummaryWriter 是 TensorBoard 的日志编写器,用于创建可视化和跟踪模型训练过程中的指标和结果。它通常用于记录模型的训练损失、准确率、权重分布、梯度分布等信息,方便后续在 TensorBoard 中进行可视化分析。
import torch
from torch.utils.tensorboard import SummaryWriter
# 创建一个 SummaryWriter 对象,指定保存日志的路径
writer = SummaryWriter('logs')
# 示例:记录模型的训练过程中的损失值和准确率
for step in range(100):
# 模拟训练过程中的损失值和准确率
loss = 0.4 * (100 - step) + torch.rand(1).item()
accuracy = 0.6 * (step / 100) + torch.rand(1).item()
# 将损失值和准确率写入日志
writer.add_scalar('Loss/train', loss, step)
writer.add_scalar('Accuracy/train', accuracy, step)
# 关闭 SummaryWriter
writer.close()
首先,我们导入了需要的库,包括 PyTorch 和 SummaryWriter。
然后,创建了一个 SummaryWriter 对象,并指定了保存日志的路径 ‘logs’。
在模拟的训练过程中,我们循环了 100 步,每一步模拟生成一个损失值和准确率。
使用 writer.add_scalar 方法将每一步的损失值和准确率写入日志。这些信息将被保存在指定的日志路径中,以便后续在 TensorBoard 中进行查看和分析。
最后,使用 writer.close() 关闭 SummaryWriter,确保日志写入完成并正确保存。
- SummaryWriter 对象的 add_scalar 方法用于向 TensorBoard 日志中添加标量数据,例如损失值、准确率等。这些标量数据通常用于跟踪模型训练过程中的指标变化。
add_scalar 方法的参数说明:
add_scalar(tag, scalar_value, global_step=None, walltime=None)
tag (str):用于标识数据的名称,将作为 TensorBoard 中的图表的名称显示。
scalar_value (float):要记录的标量数据,通常是一个数值。
global_step (int, optional):可选参数,表示记录数据的全局步骤数。主要用于绘制图表时在 x 轴上显示步骤数,方便对训练过程进行时间序列分析。
walltime (float, optional):可选参数,表示记录数据的时间戳。默认情况下,使用当前时间戳。
- add_image 函数用于向 TensorBoard 日志中添加图像数据,这对于监视模型输入、输出或中间层的可视化非常有用。
add_image 方法的参数说明:
add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')
tag (str):用于标识图像的名称,将作为 TensorBoard 中图像的名称显示。
img_tensor (Tensor or numpy.array):要记录的图像数据,可以是 PyTorch 的 Tensor 对象或者 numpy 数组。如果是 numpy 数组,会自动转换为 Tensor。
global_step (int, optional):可选参数,表示记录数据的全局步骤数。主要用于在 TensorBoard 中按时间显示图像。
walltime (float, optional):可选参数,表示记录数据的时间戳。默认情况下,使用当前时间戳。
dataformats (str, optional):可选参数,指定图像的数据格式。默认为 ‘CHW’,即通道-高度-宽度的顺序。
启动命令:tensorboard --logdir 文件路径
1.4 Transforms
transforms.py 文件通常是指用于进行数据预处理和数据增强的模块。这个模块通常用于处理图像数据,包括但不限于加载、转换、裁剪、标准化等操作,以便将数据准备好用于模型的训练或评估。
常见的 transforms.py 功能包括:
- 数据加载和预处理:
读取图像数据并转换为 PyTorch 的 Tensor 格式。
对图像进行大小调整、裁剪、旋转、镜像翻转等操作。
将图像数据标准化为特定的均值和标准差。 - 数据增强:
随机裁剪和大小调整,以增加训练数据的多样性。
随机水平或垂直翻转图像。
添加噪声或扭曲以增加数据的多样性。 - 转换为 Tensor:
将 PIL 图像或 numpy 数组转换为 PyTorch 的 Tensor 格式。
对图像数据进行归一化,例如将像素值缩放到 [0, 1] 或 [-1, 1] 之间。 - 组合多个转换操作:
将多个转换操作组合成一个流水线,可以顺序应用到图像数据上。 - 实时数据增强:
在每次训练迭代中实时生成增强后的数据,而不是预先对所有数据进行转换。
import torch
from torchvision import transforms
from PIL import Image
# 示例图像路径
img_path = 'example.jpg'
# 定义数据转换
data_transform = transforms.Compose([
transforms.Resize((256, 256)), # 调整图像大小为 256x256 像素
transforms.RandomCrop(224), # 随机裁剪为 224x224 像素
transforms.RandomHorizontalFlip(), # 随机水平翻转图像
transforms.ToTensor(), # 将图像转换为 Tensor,并归一化至 [0, 1]
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])
# 加载并预处理图像
img = Image.open(img_path)
img_transformed = data_transform(img)
# 打印转换后的图像形状和数据类型
print("Transformed image shape:", img_transformed.shape)
print("Transformed image data type:", img_transformed.dtype)
1.首先,导入了必要的库和模块,包括 PyTorch 的 transforms 模块、PIL 库中的 Image。
定义了一个 transforms.Compose 对象 data_transform,它是一个包含多个转换操作的列表,按顺序应用到输入数据上。
2.示例中的转换操作包括将图像调整为指定大小、随机裁剪为 224x224 像素、随机水平翻转、转换为 PyTorch 的 Tensor 对象,并进行归一化。
3.加载示例图像,并将 data_transform 应用到图像上,得到转换后的图像数据 img_transformed。
4.最后,打印转换后的图像形状和数据类型,以确保转换操作正确应用。
transforms.Normalize(mean, std, inplace=False)
Normalize 类用于对图像数据进行标准化操作。标准化是一种常见的数据预处理步骤,通常用于将数据缩放到一个较小的范围,以便于模型训练的稳定性和收敛速度。在图像处理中,Normalize 类主要用于将图像的每个通道进行均值和标准差的归一化处理
- mean:用于归一化的均值。可以是一个列表或元组,每个元素分别对应于图像的每个通道的均值。例如 [mean_channel1, mean_channel2, mean_channel3]。
- std:用于归一化的标准差。同样是一个列表或元组,每个元素对应于图像的每个通道的标准差。例如 [std_channel1, std_channel2, std_channel3]。
- inplace:是否原地操作。默认为 False,表示会返回一个新的标准化后的图像。如果设置为 True,则会直接修改输入的 Tensor。
Compose类
transforms.Compose(transforms)
transforms:一个由 torchvision.transforms 中的转换操作组成的列表或序列,每个操作会依次应用到输入数据上。
1.5 torchvision.datasets
PyTorch 中用于加载和处理常见视觉数据集的模块。这个模块提供了对多种经典数据集的访问接口,包括图像分类、物体检测、语义分割等任务常用的数据集。
- 数据集加载:
torchvision.datasets 提供了方便的接口,可以直接从互联网下载和加载常用的数据集,例如 MNIST、CIFAR-10、ImageNet 等。 - 数据预处理:
支持在加载数据集时进行数据预处理,例如图像大小调整、裁剪、翻转、归一化等操作,这些操作可以通过 transforms 模块进行定义和组合。 - 数据集管理:
提供了便捷的方法来管理和访问数据集,例如对数据集进行随机访问、按照批次加载数据等,以支持机器学习模型的训练和评估。 - 多样的数据集支持:
支持多种类型的数据集,包括但不限于分类数据集(如 MNIST、CIFAR-10)、检测数据集(如 COCO)、语义分割数据集(如 Pascal VOC)等,适用于不同的视觉任务。
# 以CIFAR-10为例
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# 定义数据预处理
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]