常见的Transforms(一)
1. python中__call__
的用法
代码展示:
# 定义一个名为 Person 的类
class Person:
# 定义 __call__ 方法,该方法使得类的实例可以像函数一样被调用
def __call__(self, name):
print("__call__: " + name) # 当实例被调用时,打印 "__call__: " 加上传入的 name 参数
# 定义一个名为 hello 的方法,该方法接受一个参数 name
def hello(self, name):
print("Hello " + name) # 打印 "Hello " 加上传入的 name 参数
# 创建 Person 类的一个实例,赋值给变量 person
person = Person()
# 调用 person 实例,传入参数 "zhangsan",由于 Person 类定义了 __call__ 方法,因此实例可以像函数一样被调用
person("zhangsan") # 输出: __call__: zhangsan
# 调用 person 实例的 hello 方法,传入参数 "lisi"
person.hello("lisi") # 输出: Hello lisi
说明:
__call__
方法使得类的实例可以像函数一样被调用hello
是一个普通的方法,需要通过实例来调用
2. ToTensor的使用
3. Normalize的使用
归一化:归一化的目的就是为了让不同的特征在数值上保持一致,避免某些特征对模型的影响过大,从而更好地学习到数据中的模式和关系
3.1. 代码展示与说明
# 导入所需的库
from PIL import Image # 用于图像处理的库
from torch.utils.tensorboard import SummaryWriter # 用于记录日志并可视化到 TensorBoard
from torchvision import transforms # 提供常用的图像预处理方法
# 创建一个 SummaryWriter 对象,用于将日志写入 "logs" 目录
writer = SummaryWriter("logs")
# 定义图像路径
img_path = "dataset/train/ants/0013035.jpg"
# 使用 PIL 的 Image.open 方法打开图像
img = Image.open(img_path)
print(img) # 打印图像对象的信息(如格式、大小等)
# 创建一个 transforms.ToTensor() 对象,用于将 PIL 图像或 NumPy 数组转换为 PyTorch 张量
trans_totensor = transforms.ToTensor()
# 将图像转换为张量
img_tensor = trans_totensor(img)
# 打印张量的第一个通道的第一个像素的值
print(img_tensor[0][0][0]) # 输出张量的第一个通道的第一个像素的值
# 创建一个 transforms.Normalize 对象,用于对张量进行归一化
# 参数:
# - mean: 一个列表,表示每个通道的均值。例如 [mean_ch1, mean_ch2, mean_ch3]。
# - std: 一个列表,表示每个通道的标准差。例如 [std_ch1, std_ch2, std_ch3]。
# 计算公式:output[channel] = (input[channel] - mean[channel]) / std[channel]
# 这里的 [0.5, 0.5, 0.5] 是均值和标准差,表示对每个通道进行归一化。
# 归一化后的像素值范围通常为 [-1, 1],因为:
# - 输入像素值范围是 [0, 1](由 ToTensor 转换得到)。
# - 归一化公式为 (input - 0.5) / 0.5 = 2 * input - 1,因此结果范围为 [-1, 1]。
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# 对张量进行归一化
img_norm = trans_norm(img_tensor)
# 打印归一化后的张量的第一个通道的第一个像素的值
print(img_norm[0][0][0]) # 输出归一化后的张量的第一个通道的第一个像素的值
# 将归一化后的图像张量添加到 TensorBoard 中,标签为 "Normalize"
writer.add_image("Normalize", img_norm)
# 关闭 SummaryWriter 对象,确保所有日志都被写入
writer.close()
3.2. 输出结果
3.3. 查看图片
终端输入命令:tensorboard --logdir=logs
进入tensorboard
中查看图片: