计算数据集的mean和std

文章介绍了在图像处理时如何进行数据集级别的normalize,提供了两种方法。第一种针对小数据集,通过一次性加载所有数据计算mean和std;第二种适用于大数据集,通过迭代批次计算平均值和标准差,以避免内存溢出问题。代码示例中使用了PyTorch库进行操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

图像处理时,通常需要进行normalize,最常用的是dataset level的normalize。计算数据集的meanstd的方法为:

1. 对于小数据

import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets


# default format: [b, c, h, w]
def mean_std_small(loader):
    images, labels = next(iter(loader))
    mean, std = images.mean([0, 2, 3]), images.std([0, 2, 3])

    return mean, std


data_path = "./dataset/"

transform_img = transforms.Compose([
    transforms.ToTensor(),
])

image_data = datasets.CIFAR10(root='dataset/', train=True,
                                 transform=transforms.ToTensor(),
                                 download=True)

image_data_loader = DataLoader(
    image_data,
    batch_size=len(image_data),
    shuffle=False,
    num_workers=1
)

mean, std = mean_std_small(image_data_loader)
print("mean and std: ", mean, std)

2. 对于大的数据集

import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

def batch_mean_std(loader):
    cnt = 0
    fst_moment = torch.empty(3)
    snd_moment = torch.empty(3)

    for images, _ in loader:
        b, c, h, w = images.shape
        nb_pixels = b*h*w
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2, dim=[0, 2, 3])

        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    mean, std = fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

    return mean, std


data_path = "./dataset/"

transform_img = transforms.Compose([
    transforms.ToTensor(),
])

#image_data = datasets.CIFAR10(root='dataset/', train=True,
# 							   transform=transforms.ToTensor(),
#                              download=True)

image_data = torchvision.datasets.ImageFolder(
    root=data_path, transform=transform_img
)

image_data_loader = DataLoader(
    image_data,
    batch_size=len(image_data),
    shuffle=False,
    num_workers=0
)

batch_size = 64
loader = DataLoader(image_data, batch_size=batch_size,
                    num_workers=1)


mean, std = batch_mean_std(loader)
print("mean and std: ", mean, std)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值