torchvision.utils.save_image saves all-black images

This post walks through debugging all-black images saved by torchvision.utils.save_image while training a VAE, covering the model-structure check and the learning-rate adjustment, and shares the fix that finally worked.


Problem: while training a VAE, the reconstructed images saved with torchvision.utils.save_image came out completely black, and over the course of training the saved images gradually faded from gray to black.
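For context, the saving step looked roughly like the sketch below. This is a minimal sketch of the assumed setup; the model's forward signature, the variable names, and the 1x28x28 image shape are illustrative assumptions, not taken from the original code.

```python
from torchvision.utils import save_image

# Hypothetical VAE forward pass: `recon` is the decoder output for one batch,
# squashed into [0, 1] by a final sigmoid; shapes are assumed (MNIST-like 28x28).
recon, mu, logvar = model(images)

# save_image expects float tensors in [0, 1] by default and scales them to
# 0-255 itself, so no manual *255 is needed before calling it.
save_image(recon.view(-1, 1, 28, 28).cpu(), 'reconstruction.png')
```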
[Figure: the saved output, fading from gray to black]
The original image's pixel values after normalization:
[Figure: normalized pixel values of the original image]
Part of the data from the reconstructed image (which looks perfectly normal for normalized values):
[Figure: sample values from the reconstructed image]
What I tried
1. At first I suspected the model architecture, but after going over it for a long time I couldn't find anything wrong.
2. Maybe save_image itself was the problem, so I switched to OpenCV's imwrite to save the images; the result was still all black.
3. A normalization issue: since the reconstructed values lie between 0 and 1, I suspected they were all being written out as 0, so I multiplied every value by 255. The saved image was still all black. (A quick range-check sketch is given at the end of this post.)
4. Stepping through the code, I then noticed that the loss barely changed from one iteration to the next, and sometimes did not change at all. That made me wonder whether the parameters were actually being updated, which pointed to the optimizer's learning rate (too low and learning is very slow; too high and the model can fail to converge). The initial learning rate was 1e-3 = 0.001; after changing it to 0.0005, the generated images were no longer all black! (See the sketch below.)
[Figure: reconstructed images after lowering the learning rate, no longer all black]
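As a minimal sketch of the change in step 4: the optimizer type below (Adam) is an assumption; only the learning-rate values come from this post.

```python
import torch.optim as optim

# Before: with lr=1e-3 the loss barely moved from one iteration to the next.
# optimizer = optim.Adam(model.parameters(), lr=1e-3)

# After: lowering the learning rate to 5e-4 got training moving again and the
# saved reconstructions stopped coming out all black.
optimizer = optim.Adam(model.parameters(), lr=5e-4)
```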
The all-black problem is solved, but the quality of the reconstructed images is still not great, so the model needs further tuning.
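If you run into the same symptom, a quick sanity check (referenced in step 3 above) is to print the value range of the tensor right before saving. This is only an illustrative sketch, not from the original code:

```python
from torchvision.utils import save_image

# `recon` is the reconstructed batch from the sketch above; print its range
# right before saving.
print(recon.min().item(), recon.max().item(), recon.mean().item())

# If everything is close to 0, the black image comes from the data rather than
# from save_image. If the values are in [0, 255] instead of [0, 1], either
# divide by 255 first or let save_image rescale by the batch min/max:
save_image(recon.view(-1, 1, 28, 28), 'reconstruction_check.png', normalize=True)
```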
