Pytorch深入浅出(二)之数据预处理Transforms(下)

自定义Transforms方法

试想一下,在使用 torchvision.transforms 做图像增强时,我们经常会遇到:官方 API 提供不了的操作(例如添加椒盐噪声、局部模糊、自定义颜色扰动等),这种情况下,就需要我们自定义——transforms。即自定义transforms类并“以参数形式”传到 Dataset 的__init__初始化里。
实践中最推荐的方式是:
👉 写一个和官方 Transform 完全相同接口的类 (__call__)
👉 再通过 Compose() 与其他 transform 串联使用
这样可读性好、扩展性强,也能保持项目代码风格一致。

  • 1.transforms的底层核心思想
    官网 transforms 看上去很多,但其实底层哲学只有一句:一个可调用对象列表
    可以是:
  • 一个类(例如 RandomCropColorJitter等)
  • 一个函数(例如 Lambda
    但无论是哪种,它们都必须是 callable —— 也就是:
  • 你能用 t(img) 的形式调用它(类需要实现 __call__(),函数本身可直接调用
class RandomCrop:        # 可以是类
    def __call__(self, img):
        ...

然后由Compose顺序执行:

class Compose(object):
    """Composes several transforms together.
    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms
    def __call__(self, img):
        for t in self.transforms:
            img = t(img)  # 挨个执行
        return img

只要你写的类能被调用(实现了 __call__ 方法),它就可以和官方 transforms 一样被 Compose 调用。
⭐ transforms 本质上只是:img → img 的一个函数
所有 transform 都必须满足:

  • 输入只有一个参数:img

  • 输出也必须是 img

  • 输入输出格式要与 pipeline 上下游一致(非常重要)
    这就是需要遵守的 transform 协议(protocol)。

  • 2.自定义Transforms的步骤

  1. 自定义一个类YourTransforms,结构类似Compose类
  2. __init__函数作为多参数传入的地方
  3. __call__函数具体实现自定义的transforms方法
class YourTransform(object):
    def __init__(self, ...):  # ...是要传入的多个参数
        """
        对多参数进行传入 
        如 self.p = p 传入概率 
        ......
        """

    def __call__(self, img):
        """
        该自定义transforms方法的具体实现过程
        处理输入的图像,必须返回一张同类型的图像
        """
        return img
  • 3.自定义transforms实例:椒盐噪声
import torch
from torchvision.transforms import functional as F

class SaltPepperNoise(object):
    """更适合训练的椒盐噪声(Tensor版)
    prob: 噪声发生的概率(像素级别)
    """
    def __init__(self, prob=0.02):
        self.prob = prob

    def __call__(self, img):
        # 保证输入最终是 Tensor
        img = F.to_tensor(img) if not isinstance(img, torch.Tensor) else img.clone()
        # 生成和img形状一样的随机矩阵,每个像素 ∈ [0, 1] ,为每个像素生成的随机“掩码”
        noise = torch.rand_like(img)
        # 盐噪声白点(像素值 = 1.0),noise范围:[0, prob/2)
        img[noise < self.prob / 2] = 1.0
        # 椒噪声(黑点像素值 = 0.0),noise范围:[prob/2, prob)
        img[(noise >= self.prob / 2) & (noise < self.prob)] = 0.0
        # noise范围:[prob, 1],不变
        return img

拿树叶分类竞赛数据集写一个示例,CustomDataset代码见前面笔记。

data_transform = transforms.Compose([
		SaltPepperNoise(prob=0.1)
])
  
train_csv_file = 'D:/tmp\A-tmp\models\classify-leaves/train.csv'
dataset = CustomDataset(train_csv_file, data_transform=data_transform)

import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
  
# 取一个样本看看噪声
img, label = dataset[11]
  
# Tensor 转 PIL 才能显示(ToTensor 是 0~1)
img_show = F.to_pil_image(img)
  
plt.figure(figsize=(5,5))
plt.imshow(img_show)
plt.title(f"Label: {label}")
plt.axis("off")
plt.show()

在这里插入图片描述

最终这里也给出一个Numpy版本的椒盐噪声添加,读者可自行学习理解。

class AddPepperNoise(object):
    """增加椒盐噪声(Numpy版)
    snr (float): Signal Noise Rate
    p (float): 概率值,依概率执行该操作
    """
    def __init__(self, snr, p=0.9):
        assert isinstance(snr, float) and (isinstance(p, float))
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """
        输入参数 img: PIL Image
        返回也是 PIL image
        """
        if random.uniform(0, 1) < self.p:
            img_ = np.array(img).copy()
            h, w, c = img_.shape
            signal_pct = self.snr
            noise_pct = (1 - self.snr)
            mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])
            mask = np.repeat(mask, c, axis=2)
            img_[mask == 1] = 255   # 盐噪声
            img_[mask == 2] = 0     # 椒噪声
            return Image.fromarray(img_.astype('uint8')).convert('RGB')
        else:
            return img
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值