图像处理时，通常需要对图像做 normalize，最常用的是 dataset level 的 normalize，即用整个数据集每个通道的 mean 和 std 进行标准化。计算数据集 mean 和 std 的方法如下：
1. 对于小数据集（可以一次性全部载入内存）：
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
# Image batches are laid out as [b, c, h, w] (batch, channel, height, width).
def mean_std_small(loader):
    """Return per-channel ``(mean, std)`` of the FIRST batch yielded by ``loader``.

    Intended for datasets small enough to fit in one batch: build the
    DataLoader with ``batch_size=len(dataset)`` so that the first batch is the
    whole dataset. Any batches after the first are ignored.

    Args:
        loader: iterable yielding ``(images, labels)`` pairs; ``images`` is a
            4-D tensor in ``[b, c, h, w]`` layout. Labels are unused.

    Returns:
        Two 1-D tensors of length ``c``. ``std`` uses ``Tensor.std``'s default
        (Bessel-corrected) estimator.
    """
    imgs, _labels = next(iter(loader))
    # Reduce over batch, height and width; keep the channel axis.
    reduce_dims = [0, 2, 3]
    channel_mean = imgs.mean(reduce_dims)
    channel_std = imgs.std(reduce_dims)
    return channel_mean, channel_std
data_path = "./dataset/"
transform_img = transforms.Compose([
    transforms.ToTensor(),
])
# Use the path and transform defined above instead of repeating them inline,
# so that editing either one actually takes effect.
image_data = datasets.CIFAR10(root=data_path, train=True,
                              transform=transform_img,
                              download=True)
# One batch holding the whole dataset: mean_std_small only reads the first
# batch. With a single batch there is nothing to parallelize, so load in the
# main process (num_workers=0) instead of spawning a worker that must ship
# the entire dataset tensor back.
image_data_loader = DataLoader(
    image_data,
    batch_size=len(image_data),
    shuffle=False,
    num_workers=0,
)
mean, std = mean_std_small(image_data_loader)
print("mean and std: ", mean, std)
2. 对于大数据集（无法一次性载入内存，需要按 batch 累积统计量）：
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
def batch_mean_std(loader):
    """Compute per-channel mean and std of a dataset by streaming over batches.

    Only running per-channel sums are kept, so the dataset never has to fit
    in memory at once.

    Args:
        loader: iterable yielding ``(images, labels)`` pairs, where ``images``
            is a 4-D tensor in ``[b, c, h, w]`` layout. Labels are ignored.

    Returns:
        ``(mean, std)``: 1-D tensors of length ``c``, in the images' dtype
        (float32 if the images are not floating point). ``std`` is the
        population standard deviation (divides by the pixel count, no Bessel
        correction).

    Raises:
        ValueError: if ``loader`` yields no pixels.
    """
    total = None      # running per-channel sum of pixel values
    total_sq = None   # running per-channel sum of squared pixel values
    cnt = 0           # pixels seen per channel
    out_dtype = torch.float32
    for images, _ in loader:
        b, c, h, w = images.shape
        if total is None:
            # Start from zeros (torch.empty is uninitialized and may hold
            # NaN/inf), sized by the data's actual channel count.
            total = torch.zeros(c, dtype=torch.float64, device=images.device)
            total_sq = torch.zeros_like(total)
            if images.is_floating_point():
                out_dtype = images.dtype
        # Accumulate in float64 so sums over many pixels keep their precision.
        batch = images.to(torch.float64)
        total += batch.sum(dim=[0, 2, 3])
        total_sq += (batch * batch).sum(dim=[0, 2, 3])
        cnt += b * h * w
    if cnt == 0:
        raise ValueError("loader yielded no pixels; cannot compute mean/std")
    mean = total / cnt
    # E[x^2] - E[x]^2 can dip just below 0 from rounding; clamp before sqrt
    # so near-constant channels give 0 rather than NaN.
    var = (total_sq / cnt - mean * mean).clamp_min(0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)
data_path = "./dataset/"
transform_img = transforms.Compose([
    transforms.ToTensor(),
])
# NOTE(review): presumably data_path is laid out the way ImageFolder expects
# (one sub-folder per class) — verify against the actual dataset directory.
image_data = torchvision.datasets.ImageFolder(
    root=data_path, transform=transform_img
)
# Stream the dataset in modest batches so only one batch is in memory at a
# time; batch_mean_std accumulates the statistics across batches.
# NOTE(review): default collation stacks images, so every image must have the
# same size; add a Resize/CenterCrop to transform_img if the folder mixes sizes.
batch_size = 64
loader = DataLoader(image_data, batch_size=batch_size,
                    shuffle=False, num_workers=1)
mean, std = batch_mean_std(loader)
print("mean and std: ", mean, std)