自定义Transforms方法
试想一下,在使用 torchvision.transforms 做图像增强时,我们经常会遇到:官方 API 提供不了的操作(例如添加椒盐噪声、局部模糊、自定义颜色扰动等),这种情况下,就需要我们自定义——transforms。即自定义transforms类并“以参数形式”传到 Dataset 的__init__初始化里。
实践中最推荐的方式是:
👉 写一个和官方 Transform 完全相同接口的类 (__call__)
👉 再通过 Compose() 与其他 transform 串联使用
这样可读性好、扩展性强,也能保持项目代码风格一致。
- 1.transforms的底层核心思想
官网 transforms 看上去很多,但其实底层哲学只有一句:一个可调用对象列表
可以是: - 一个类(例如
RandomCrop、ColorJitter等) - 一个函数(例如
Lambda)
但无论是哪种,它们都必须是 callable —— 也就是: - 你能用
t(img)的形式调用它(类需要实现__call__(),函数本身可直接调用
class RandomCrop: # 可以是类
def __call__(self, img):
...
然后由Compose顺序执行:
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img) # 挨个执行
return img
只要你写的类能被调用(实现了 __call__ 方法),它就可以和官方 transforms 一样被 Compose 调用。
⭐ transforms 本质上只是:img → img 的一个函数
所有 transform 都必须满足:
-
输入只有一个参数:
img -
输出也必须是
img -
输入输出格式要与 pipeline 上下游一致(非常重要)
这就是需要遵守的 transform 协议(protocol)。 -
2.自定义Transforms的步骤
- 自定义一个类YourTransforms,结构类似Compose类
__init__函数作为多参数传入的地方__call__函数具体实现自定义的transforms方法
class YourTransform(object):
def __init__(self, ...): # ...是要传入的多个参数
"""
对多参数进行传入
如 self.p = p 传入概率
......
"""
def __call__(self, img):
"""
该自定义transforms方法的具体实现过程
处理输入的图像,必须返回一张同类型的图像
"""
return img
- 3.自定义transforms实例:椒盐噪声
import torch
from torchvision.transforms import functional as F
class SaltPepperNoise(object):
"""更适合训练的椒盐噪声(Tensor版)
prob: 噪声发生的概率(像素级别)
"""
def __init__(self, prob=0.02):
self.prob = prob
def __call__(self, img):
# 保证输入最终是 Tensor
img = F.to_tensor(img) if not isinstance(img, torch.Tensor) else img.clone()
# 生成和img形状一样的随机矩阵,每个像素 ∈ [0, 1] ,为每个像素生成的随机“掩码”
noise = torch.rand_like(img)
# 盐噪声白点(像素值 = 1.0),noise范围:[0, prob/2)
img[noise < self.prob / 2] = 1.0
# 椒噪声(黑点像素值 = 0.0),noise范围:[prob/2, prob)
img[(noise >= self.prob / 2) & (noise < self.prob)] = 0.0
# noise范围:[prob, 1],不变
return img
拿树叶分类竞赛数据集写一个示例,CustomDataset代码见前面笔记。
data_transform = transforms.Compose([
SaltPepperNoise(prob=0.1)
])
train_csv_file = 'D:/tmp\A-tmp\models\classify-leaves/train.csv'
dataset = CustomDataset(train_csv_file, data_transform=data_transform)
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
# 取一个样本看看噪声
img, label = dataset[11]
# Tensor 转 PIL 才能显示(ToTensor 是 0~1)
img_show = F.to_pil_image(img)
plt.figure(figsize=(5,5))
plt.imshow(img_show)
plt.title(f"Label: {label}")
plt.axis("off")
plt.show()

最终这里也给出一个Numpy版本的椒盐噪声添加,读者可自行学习理解。
class AddPepperNoise(object):
"""增加椒盐噪声(Numpy版)
snr (float): Signal Noise Rate
p (float): 概率值,依概率执行该操作
"""
def __init__(self, snr, p=0.9):
assert isinstance(snr, float) and (isinstance(p, float))
self.snr = snr
self.p = p
def __call__(self, img):
"""
输入参数 img: PIL Image
返回也是 PIL image
"""
if random.uniform(0, 1) < self.p:
img_ = np.array(img).copy()
h, w, c = img_.shape
signal_pct = self.snr
noise_pct = (1 - self.snr)
mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])
mask = np.repeat(mask, c, axis=2)
img_[mask == 1] = 255 # 盐噪声
img_[mask == 2] = 0 # 椒噪声
return Image.fromarray(img_.astype('uint8')).convert('RGB')
else:
return img
2173

被折叠的 条评论
为什么被折叠?



