一、创建自定义Dataset类:(命名为mydataset.py)
import os
from PIL import Image
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
def __init__(self, image_dir, transform=None): # 添加transform参数
self.image_dir = image_dir # 添加image_dir属性
self.transform = transform # 添加transform属性
self.image_names = os.listdir(image_dir)# 使用os.listdir获取图片名
def __len__(self): # 返回图片数量
return len(self.image_names)
def __getitem__(self, idx):
img_path = os.path.join(self.image_dir, self.image_names[idx]) # 获取图片路径
image = Image.open(img_path).convert("RGB") # 读取图片并转换为RGB格式
if self.transform:
image = self.transform(image) # 使用transform对图片进行处理
return image
二、实例化自定义类并在Tensorboard中可视化:
from mydataset import CustomImageDataset
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
# 使用示例
image_dir = '/home/yyy/code_dataset/train_1'
transform = transforms.Compose([
transforms.Resize((256, 256)), # 确保图像大小一致
transforms.ToTensor(), # 将图像转换为张量
# transforms.Normalize(mean=(0.4347, 0.4335, 0.3502), std=(0.1702, 0.1540, 0.1660)) # 常见的标准化
])
dataset = CustomImageDataset(image_dir=image_dir, transform=transform)
# 创建 SummaryWriter
log_dir = 'runs/image_example'
writer = SummaryWriter(log_dir)
# 将数据集中的图像添加到 TensorBoard
for i in range(min(len(dataset), 10)): # 仅展示前10张图像
image = dataset[i]
writer.add_image(f'Image_{i}', image, i)
# 关闭 SummaryWriter
writer.close()
打开另一个终端,在终端中输入以下命令打开tensorboard:
tensorboard --logdir /home/yyy/code_dataset/mycode/runs/image_example --load_fast true
结果如下图:
点击蓝色链接,浏览器打开tensorboard,观察可视化图片:
三、运用该实例计算需要初始化均值和标准差的值
import torch
from mydataset import CustomImageDataset
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# 使用自定义数据集
root_dir = '/home/yyy/code_dataset/train_1'
transform = transforms.Compose([
transforms.ToTensor()
])
dataset = CustomImageDataset(image_dir=root_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=4)
# 接下来可以使用这个 dataloader 进行均值和标准差的计算
# 初始化均值和标准差
mean = torch.zeros(3)
std = torch.zeros(3)
# 累加所有图像的总像素值
for images in tqdm(dataloader, desc='Computing mean and std'):
# 如果 images 是 3 维张量,使用 unsqueeze 在第 0 维添加一个维度
if images.dim() == 3:
images = images.unsqueeze(0)
# 累加均值和标准差
for i in range(3):
mean[i] += images[:, i, :, :].mean() * images.size(0)
std[i] += images[:, i, :, :].std() * images.size(0)
# 除以总的图像数量来获取最终的均值和标准差
mean /= len(dataloader.dataset)
std /= len(dataloader.dataset)
print(f'Mean: {mean}')
print(f'Std: {std}')
计算结果如下: