一、transforms 的核心结构
1. 模块定位
-
作用:专门用于图像数据的预处理和数据增强
-
位置:
torchvision.transforms
模块 -
输入类型:支持 PIL Image、Tensor、numpy array(部分操作)
2. 核心组件
from torchvision import transforms
# 基础结构
transform_pipeline = transforms.Compose([
transforms.Resize(256), # 调整尺寸
transforms.RandomCrop(224), # 随机裁剪
transforms.ToTensor(), # 转为张量
transforms.Normalize( # 标准化
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
二、transforms 的六大功能类别
1. 几何变换(空间操作)
变换名称 | 功能说明 | 关键参数示例 |
---|---|---|
Resize | 调整图像尺寸 | size=(256,256) 或 256 |
RandomCrop | 随机裁剪 | size=224 |
CenterCrop | 中心裁剪 | size=224 |
RandomResizedCrop | 随机缩放裁剪 | size=224, scale=(0.08,1.0) |
RandomRotation | 随机旋转 | degrees=30 |
RandomHorizontalFlip | 随机水平翻转(概率默认0.5) | p=0.5 |
RandomVerticalFlip | 随机垂直翻转 | p=0.5 |
示例:
transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), # 随机缩放裁剪
transforms.RandomRotation(15), # ±15度随机旋转
transforms.RandomHorizontalFlip(p=0.7) # 70%概率水平翻转
])
2. 颜色变换
变换名称 | 功能说明 | 关键参数示例 |
---|---|---|
ColorJitter | 随机调整亮度/对比度/饱和度/色相 | brightness=0.2, contrast=0.2 |
RandomGrayscale | 随机灰度化(概率默认0.1) | p=0.1 |
RandomAdjustSharpness | 随机调整锐度 | sharpness_factor=2 |
GaussianBlur | 高斯模糊 | kernel_size=(5,5) |
RandomAutocontrast | 自动对比度调整 | - |
示例:
transforms.ColorJitter(
brightness=0.3, # 亮度变化幅度
contrast=0.3, # 对比度变化幅度
saturation=0.3, # 饱和度变化幅度
hue=0.1 # 色相变化范围(-hue, +hue)
)
3. 张量转换与标准化
变换名称 | 功能说明 |
---|---|
ToTensor | 将 PIL Image 或 numpy.ndarray 转换为 Tensor(HWC → CHW,归一化到 [0,1]) |
Normalize | 用均值和标准差标准化 |
ConvertImageDtype | 转换图像数据类型(如 float32 → uint8) |
标准化原理:
# 计算公式
normalized_image = (image - mean) / std
# 典型ImageNet参数
transforms.Normalize(
mean=[0.485, 0.456, 0.406], # R/G/B通道均值
std=[0.229, 0.224, 0.225] # R/G/B通道标准差
)
4. 特殊操作
变换名称 | 功能说明 |
---|---|
RandomErasing | 随机区域擦除(Cutout) |
RandomPerspective | 随机透视变换 |
FiveCrop | 生成四个角+中心共5个裁剪 |
TenCrop | 生成5个裁剪+水平翻转共10个样本 |
RandomErasing 示例:
transforms.RandomErasing(
p=0.5, # 应用概率
scale=(0.02, 0.33), # 擦除区域占比范围
ratio=(0.3, 3.3), # 宽高比范围
value=0 # 填充值(或随机值)
)
5. 组合与复用
方法 | 功能说明 |
---|---|
Compose | 顺序执行多个变换 |
RandomApply | 随机选择是否应用一组变换 |
RandomChoice | 随机选择执行一个变换 |
RandomOrder | 随机打乱变换顺序 |
高级组合示例:
transforms.RandomApply([
transforms.RandomRotation(30),
transforms.ColorJitter()
], p=0.5) # 50%概率应用旋转和颜色抖动
6. 自定义变换
class MyTransform(object):
def __init__(self, param):
self.param = param
def __call__(self, img):
# 实现自定义变换逻辑
return transformed_img
# 使用方式
transforms.Compose([
MyTransform(0.5),
transforms.ToTensor()
])
三、transforms 的三大使用场景
场景1:训练集与验证集的不同处理
# 训练集增强流程
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 验证集处理流程
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
场景2:多模态数据处理
# 同时处理图像和深度图
dual_transform = transforms.Compose([
transforms.Resize(256),
transforms.Lambda(
lambda x: (x[0], x[1]) if isinstance(x, tuple) else x
)
])
场景3:时序数据增强
# 视频帧处理
video_transform = transforms.Compose([
transforms.Resize((112, 112)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.5, 0.5, 0.5),
transforms.ToTensor()
])
四、最佳实践与调试技巧
1. 可视化检查
def visualize_transform(dataset, idx=0, n_col=5):
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 5))
for i in range(n_col):
img, _ = dataset[idx]
plt.subplot(1, n_col, i+1)
plt.imshow(img.permute(1, 2, 0).numpy())
plt.show()
# 使用示例
visualize_transform(train_dataset)
2. 参数优化策略
-
动态增强:随着训练轮数增加逐步减少增强强度
-
自动增强:使用
torchvision.transforms.AutoAugment
transforms.AutoAugment(
policy=transforms.AutoAugmentPolicy.IMAGENET # 使用ImageNet策略
)
3. 性能优化
-
预处理缓存:对固定变换结果进行缓存
-
GPU加速:使用
kornia
库实现GPU端数据增强
import kornia.augmentation as K
gpu_aug = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomGaussianBlur(kernel_size=(3,3), sigma=(0.1, 2.0), p=0.5),
data_keys=["input"]
)
4. 常见错误排查
错误现象 | 解决方案 |
---|---|
图像尺寸不匹配 | 检查 Resize 和 Crop 参数 |
数值范围异常 | 确认 ToTensor 和 Normalize 顺序 |
颜色通道错误 | 确保使用 .convert('RGB') |
张量维度错误 | 检查是否遗漏 ToTensor |