torchvision.utils.save_image 讲解

文章讲述了在使用PyTorch的`torchvision.utils.save_image`函数保存图像时,normalize参数的不同设置如何影响图像的显示和读取。当normalize为False时,保存的是包含三个通道的伪灰度图像;而normalize为True则会进行归一化处理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

因为这个让我犯过错误,我记录一下。
其他参数百度即可。
主要讲解normalize

a = torch.tensor([0.8, 0.9])
a = a.unsqueeze(0)
save_path = f"111.png"
torchvision.utils.save_image(a, save_path, normalize=False, cmap='gray')

结果:
在这里插入图片描述

normalize为False时,使用Image加载图片

# 使用 Image.open 加载图像
image = Image.open('111.png')
a = np.array(image)

结果:
在这里插入图片描述
发现其实它保存的是三个通道的伪灰度图像。
每个像素值通过0.8x255 0.9x255得来的。

换一个方法读取图片,只希望读取单通道。

在这里插入图片描述

注意
如果将normalize设置为True
他会根据最大值 最小值的归一化方法将其进行线性变化。
0.8变为最小的0 0.9变为最大的1
在这里插入图片描述
Y=(x-最小)/(最大-最小)
1 = (0.9 - 0.8) / (0.9 -0.8)
0 = (0.8-0.8)/ (0.9-0.8)


在这里插入图片描述

Normalize为False时,超过1的被截取为1了。

小于0的被截取为0了。

由于任务较为复杂,我们将代码分成几部分来讲解。 ## 1. 数据预处理 首先,我们需要下载 MNIST 数据集并进行预处理。 ```python import torch from torchvision import datasets, transforms # 定义数据预处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) # 下载训练集 train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 加载训练集 train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True) ``` ## 2. 构建模型 接下来,我们需要构建 Diffusion Model。 ```python import torch.nn as nn class DiffusionModel(nn.Module): def __init__(self, input_size=784, hidden_size=256, output_size=784, num_layers=2): super(DiffusionModel, self).__init__() # 编码器 encoder_layers = [] for i in range(num_layers): if i == 0: encoder_layers.append(nn.Linear(input_size, hidden_size)) else: encoder_layers.append(nn.Linear(hidden_size, hidden_size)) encoder_layers.append(nn.ReLU()) self.encoder = nn.Sequential(*encoder_layers) # 解码器 decoder_layers = [] for i in range(num_layers): if i == 0: decoder_layers.append(nn.Linear(hidden_size, output_size)) else: decoder_layers.append(nn.Linear(hidden_size, hidden_size)) decoder_layers.append(nn.ReLU()) decoder_layers.append(nn.Linear(hidden_size, input_size)) decoder_layers.append(nn.Tanh()) self.decoder = nn.Sequential(*decoder_layers) def forward(self, x): z = self.encoder(x) y = self.decoder(z) return y ``` ## 3. 训练模型 现在,我们可以开始训练模型了。 ```python import time import matplotlib.pyplot as plt # 定义模型和优化器 model = DiffusionModel() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = nn.MSELoss() # 训练模型 num_epochs = 50 losses = [] start_time = time.time() for epoch in range(num_epochs): for batch_idx, (data, _) in enumerate(train_loader): data = data.view(data.shape[0], -1) optimizer.zero_grad() recon_data = model(data) loss = criterion(recon_data, data) loss.backward() optimizer.step() losses.append(loss.item()) print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item())) # 打印训练时间 end_time = time.time() print('Training Time: {:.2f}s'.format(end_time - start_time)) # 可视化损失函数 plt.plot(losses) plt.xlabel('Iterations') plt.ylabel('Loss') plt.show() ``` ## 4. 保存模型和样本 最后,我们可以保存模型和生成手写数字样本。 ```python import os import torchvision.utils as vutils # 创建目录 os.makedirs('images', exist_ok=True) # 保存模型 torch.save(model.state_dict(), 'diffusion_model.pth') # 生成样本 num_samples = 64 z = torch.randn(num_samples, 256) samples = model.decoder(z) samples = samples.view(num_samples, 1, 28, 28) vutils.save_image(samples, 'images/samples.png', normalize=True, nrow=8) ``` 现在,你可以在 images/samples.png 中查看生成的手写数字样本了。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值