transforms.Normalize()
是 PyTorch 中用于图像数据标准化的函数,通过对每个通道执行减均值并除以标准差的操作,使数据分布更稳定。以下是关键解析:参数解析
mean (序列): 各通道的均值。
std (序列): 各通道的标准差。
例如,对 RGB 图像常用 ImageNet 的统计值:
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
。数学公式
对于每个通道,计算方式为
假设输入张量范围为 [0, 1],处理后数据分布接近均值为 0、标准差为 1。
使用要点
1. 顺序与转换
应在ToTensor()
之后应用,因ToTensor()
将图像转换为 [C, H, W] 形状的张量,并将像素值缩放到 [0, 1]。transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) # 单通道示例 ])
2. 数据一致性
使用预训练模型时,须与训练时的均值和标准差一致,否则可能影响性能。3. 自定义数据集
需自行计算各通道的均值和标准差。方法如下:
遍历数据集,累加各通道像素值及平方。
计算均值:总和 / 总像素数。
计算标准差:
4.错误处理
通道数不匹配会报错(如 3 通道图像使用 2 个均值)。
确保输入张量已正确缩放到 [0, 1](通过
ToTensor()
)。示例效果
单通道,
mean=0.5
,std=0.5
:
输入 0.5 → 输出 0,1.0 → 1.0,0 → -1.0,范围变为 [-1, 1]。注意事项
避免重复应用:多次归一化会破坏数据分布。
总结
transforms.Normalize()
通过标准化提升模型训练效果,需确保参数与数据特性及模型要求一致,正确应用于预处理流程中。设备支持:自动处理张量所在设备(CPU/GPU)。
01-13
1433

09-15
814

03-22
2万+
