Pytorch基础应用

1.数据加载

1.1 读取文本文件

  • 方法一:使用 open() 函数和 read() 方法
# 打开文件并读取全部内容
file_path = 'example.txt'  # 替换为你的文件路径
with open(file_path, 'r') as file:
    content = file.read()
    print(content)
  • 方法二:逐行读取文件内容
# 逐行读取文件内容
file_path = 'example.txt'  # 替换为你的文件路径
with open(file_path, 'r') as file:
    for line in file:
        print(line.strip())  # strip() 方法用于去除行末尾的换行符
  • 方法三:指定编码读取文件内容(如果文本文件不是UTF-8编码)
# 指定编码读取文件内容
file_path = 'example.txt'  # 替换为你的文件路径
with open(file_path, 'r', encoding='utf-8') as file:
    content = file.read()
    print(content)
  • 方法四:一次读取多行内容
# 一次读取多行内容
file_path = 'example.txt'  # 替换为你的文件路径
with open(file_path, 'r') as file:
    lines = file.readlines()
    for line in lines:
        print(line.strip())  # strip() 方法用于去除行末尾的换行符

1.2 读取图片

  • 使用Pillow(PIL)库

安装了Pillow库:pip install Pillow

from PIL import Image

# 打开图片文件
img = Image.open('example.jpg')  # 替换成你的图片文件路径

# 显示图片信息
print("图片格式:", img.format)
print("图片大小:", img.size)
print("图片模式:", img.mode)

# 显示图片
img.show()

# 转换为numpy数组(如果需要在其他库中处理图像)
import numpy as np
img_array = np.array(img)
  • 使用OpenCV库
    安装OpenCV库:pip install opencv-python
import cv2

# 读取图片
img = cv2.imread('example.jpg')  # 替换成你的图片文件路径

# 显示图片信息
print("图片尺寸:", img.shape)  # 高度、宽度、通道数(rgb默认是3)
print("像素值范围:", img.dtype)  # 数据类型(像素值类型)

# 可选:显示图片(OpenCV中显示图像的方法)
cv2.imshow('image', img)
cv2.waitKey(0)
cv2.destroyAllWindows()

# 可选:将BGR格式转换为RGB格式
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# 可选:保存图片
cv2.imwrite('output.jpg', img)
  • 应用
# 继承Dataset类
class MyData(Dataset):
    # 初始化
    def __init__(self, root_dir, image_dir, label_dir):
        # root_dir 训练目录地址
        self.root_dir = root_dir
        # 图片地址
        self.image_dir = image_dir
        # label_dir 标签地址
        self.label_dir = label_dir
        # 训练样本地址
        self.path = os.path.join(self.root_dir, self.image_dir)
        # 将所有样本地址(名称)变成一个列表
        self.img_path = os.listdir(self.path)

    # 读取图片,并且对应label
    # idx 图片下标
    def __getitem__(self, idx):
        # 图片名称
        img_name = self.img_path[idx]
        # 图片路径
        img_item_path = os.path.join(self.path, img_name)
        # 获取图片
        img = Image.open(img_item_path)
        # 打开文件并读取全部内容(文件存取对应样本的标签)
        file_path = os.path.join(self.root_dir, self.label_dir, img_name.split('.jpg')[0])
        with open(file_path + ".txt", 'r') as file:
            label = file.read()
        # 获取标签
        return img, label

    # 训练样本长度
    def __len__(self):
        return len(self.img_path)


root_dir = "hymenoptera_data/train"
ants_image_dir = "ants_image"
ants_label_dir = "ants_label"
bees_image_dir = "bees_image"
bees_label_dir = "bees_label"
ants_dataset = MyData(root_dir, ants_image_dir, ants_label_dir)
img, label = ants_dataset[0]  # 返回数据集第一个样本的图片和标签
img.show()  # 展示图片
bees_dataset = MyData(root_dir, bees_image_dir, bees_label_dir)
import cv2
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image

img_path = "dataset/train/ants_image/0013035.jpg"
# 打开图片
img = Image.open(img_path)

# 利用cv2打开图片,直接是ndarray格式
img_cv2 = cv2.imread(img_path)

print(img_cv2)

write = SummaryWriter('logs')

# 创建一个ToTensor的对象(PIL Image or ndarray to tensor)
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
print(tensor_img)

1.3 数据可视化(TensorBoard)

SummaryWriter 是 TensorBoard 的日志编写器,用于创建可视化和跟踪模型训练过程中的指标和结果。它通常用于记录模型的训练损失、准确率、权重分布、梯度分布等信息,方便后续在 TensorBoard 中进行可视化分析。

import torch
from torch.utils.tensorboard import SummaryWriter

# 创建一个 SummaryWriter 对象,指定保存日志的路径
writer = SummaryWriter('logs')

# 示例:记录模型的训练过程中的损失值和准确率
for step in range(100):
    # 模拟训练过程中的损失值和准确率
    loss = 0.4 * (100 - step) + torch.rand(1).item()
    accuracy = 0.6 * (step / 100) + torch.rand(1).item()

    # 将损失值和准确率写入日志
    writer.add_scalar('Loss/train', loss, step)
    writer.add_scalar('Accuracy/train', accuracy, step)

# 关闭 SummaryWriter
writer.close()

首先,我们导入了需要的库,包括 PyTorch 和 SummaryWriter。
然后,创建了一个 SummaryWriter 对象,并指定了保存日志的路径 ‘logs’。
在模拟的训练过程中,我们循环了 100 步,每一步模拟生成一个损失值和准确率。
使用 writer.add_scalar 方法将每一步的损失值和准确率写入日志。这些信息将被保存在指定的日志路径中,以便后续在 TensorBoard 中进行查看和分析。
最后,使用 writer.close() 关闭 SummaryWriter,确保日志写入完成并正确保存。

  • SummaryWriter 对象的 add_scalar 方法用于向 TensorBoard 日志中添加标量数据,例如损失值、准确率等。这些标量数据通常用于跟踪模型训练过程中的指标变化。

add_scalar 方法的参数说明:add_scalar(tag, scalar_value, global_step=None, walltime=None)
tag (str):用于标识数据的名称,将作为 TensorBoard 中的图表的名称显示。
scalar_value (float):要记录的标量数据,通常是一个数值。
global_step (int, optional):可选参数,表示记录数据的全局步骤数。主要用于绘制图表时在 x 轴上显示步骤数,方便对训练过程进行时间序列分析。
walltime (float, optional):可选参数,表示记录数据的时间戳。默认情况下,使用当前时间戳。

  • add_image 函数用于向 TensorBoard 日志中添加图像数据,这对于监视模型输入、输出或中间层的可视化非常有用。

add_image 方法的参数说明:add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')
tag (str):用于标识图像的名称,将作为 TensorBoard 中图像的名称显示。
img_tensor (Tensor or numpy.array):要记录的图像数据,可以是 PyTorch 的 Tensor 对象或者 numpy 数组。如果是 numpy 数组,会自动转换为 Tensor。
global_step (int, optional):可选参数,表示记录数据的全局步骤数。主要用于在 TensorBoard 中按时间显示图像。
walltime (float, optional):可选参数,表示记录数据的时间戳。默认情况下,使用当前时间戳。
dataformats (str, optional):可选参数,指定图像的数据格式。默认为 ‘CHW’,即通道-高度-宽度的顺序。

启动命令:tensorboard --logdir 文件路径

1.4 Transforms

transforms.py 文件通常是指用于进行数据预处理和数据增强的模块。这个模块通常用于处理图像数据,包括但不限于加载、转换、裁剪、标准化等操作,以便将数据准备好用于模型的训练或评估。

常见的 transforms.py 功能包括:

  • 数据加载和预处理:
    读取图像数据并转换为 PyTorch 的 Tensor 格式。
    对图像进行大小调整、裁剪、旋转、镜像翻转等操作。
    将图像数据标准化为特定的均值和标准差。
  • 数据增强:
    随机裁剪和大小调整,以增加训练数据的多样性。
    随机水平或垂直翻转图像。
    添加噪声或扭曲以增加数据的多样性。
  • 转换为 Tensor:
    将 PIL 图像或 numpy 数组转换为 PyTorch 的 Tensor 格式。
    对图像数据进行归一化,例如将像素值缩放到 [0, 1] 或 [-1, 1] 之间。
  • 组合多个转换操作:
    将多个转换操作组合成一个流水线,可以顺序应用到图像数据上。
  • 实时数据增强:
    在每次训练迭代中实时生成增强后的数据,而不是预先对所有数据进行转换。
import torch
from torchvision import transforms
from PIL import Image

# 示例图像路径
img_path = 'example.jpg'

# 定义数据转换
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),      # 调整图像大小为 256x256 像素
    transforms.RandomCrop(224),         # 随机裁剪为 224x224 像素
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),              # 将图像转换为 Tensor,并归一化至 [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

# 加载并预处理图像
img = Image.open(img_path)
img_transformed = data_transform(img)

# 打印转换后的图像形状和数据类型
print("Transformed image shape:", img_transformed.shape)
print("Transformed image data type:", img_transformed.dtype)

1.首先,导入了必要的库和模块,包括 PyTorch 的 transforms 模块、PIL 库中的 Image。
定义了一个 transforms.Compose 对象 data_transform,它是一个包含多个转换操作的列表,按顺序应用到输入数据上。
2.示例中的转换操作包括将图像调整为指定大小、随机裁剪为 224x224 像素、随机水平翻转、转换为 PyTorch 的 Tensor 对象,并进行归一化。
3.加载示例图像,并将 data_transform 应用到图像上,得到转换后的图像数据 img_transformed。
4.最后,打印转换后的图像形状和数据类型,以确保转换操作正确应用。

transforms.Normalize(mean, std, inplace=False)

Normalize 类用于对图像数据进行标准化操作。标准化是一种常见的数据预处理步骤,通常用于将数据缩放到一个较小的范围,以便于模型训练的稳定性和收敛速度。在图像处理中,Normalize 类主要用于将图像的每个通道进行均值和标准差的归一化处理

  1. mean:用于归一化的均值。可以是一个列表或元组,每个元素分别对应于图像的每个通道的均值。例如 [mean_channel1, mean_channel2, mean_channel3]。
  2. std:用于归一化的标准差。同样是一个列表或元组,每个元素对应于图像的每个通道的标准差。例如 [std_channel1, std_channel2, std_channel3]。
  3. inplace:是否原地操作。默认为 False,表示会返回一个新的标准化后的图像。如果设置为 True,则会直接修改输入的 Tensor。

Compose类
transforms.Compose(transforms)
transforms:一个由 torchvision.transforms 中的转换操作组成的列表或序列,每个操作会依次应用到输入数据上。

1.5 torchvision.datasets

PyTorch 中用于加载和处理常见视觉数据集的模块。这个模块提供了对多种经典数据集的访问接口,包括图像分类、物体检测、语义分割等任务常用的数据集。

  • 数据集加载:
    torchvision.datasets 提供了方便的接口,可以直接从互联网下载和加载常用的数据集,例如 MNIST、CIFAR-10、ImageNet 等。
  • 数据预处理:
    支持在加载数据集时进行数据预处理,例如图像大小调整、裁剪、翻转、归一化等操作,这些操作可以通过 transforms 模块进行定义和组合。
  • 数据集管理:
    提供了便捷的方法来管理和访问数据集,例如对数据集进行随机访问、按照批次加载数据等,以支持机器学习模型的训练和评估。
  • 多样的数据集支持:
    支持多种类型的数据集,包括但不限于分类数据集(如 MNIST、CIFAR-10)、检测数据集(如 COCO)、语义分割数据集(如 Pascal VOC)等,适用于不同的视觉任务。
# 以CIFAR-10为例
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 定义数据预处理
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yoin.

感谢各位打赏!!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值