import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import numpy as np
from tqdm import tqdm
from collections import defaultdict
# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128       # samples per training batch
image_size = 64        # images resized/cropped to 64x64
channels = 1           # grayscale images (single channel)
latent_dim = 128       # size of the VAE latent vector (also label-embedding size)
learning_rate = 1e-3
epochs = 130
# Data loading and preprocessing (includes grayscale conversion)
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.Grayscale(num_output_channels=1),  # force single-channel input
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # single-channel normalization to [-1, 1]
])
data_dir = r'I:\codes\新'
# ImageFolder: one subfolder per class; targets are integer labels 0..C-1.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
# Stratified sampler: every batch draws an equal number of indices from each
# class so minority classes are not under-represented during training.
class StratifiedSampler(Sampler):
    def __init__(self, labels, batch_size):
        """
        Args:
            labels: per-sample integer class labels (0..C-1, e.g. ImageFolder.targets).
            batch_size: target batch size; each class contributes
                batch_size // num_classes indices per batch.
        """
        self.labels = labels
        self.batch_size = batch_size
        self.class_indices = defaultdict(list)
        for i, label in enumerate(labels):
            self.class_indices[label].append(i)
        self.num_classes = len(self.class_indices)
        # Per-batch quota for each class (batches are smaller than batch_size
        # when batch_size is not divisible by num_classes).
        self.samples_per_class = batch_size // self.num_classes
        self.num_batches = len(labels) // batch_size

    def __iter__(self):
        indices = []
        for _ in range(self.num_batches):
            for class_idx in range(self.num_classes):
                pool = self.class_indices[class_idx]
                # Sample with replacement only when the class is smaller than
                # its per-batch quota (the original unconditional replace=False
                # raised ValueError for such classes).
                indices.extend(np.random.choice(
                    pool,
                    self.samples_per_class,
                    replace=len(pool) < self.samples_per_class))
        return iter(indices)

    def __len__(self):
        # Must equal the number of indices __iter__ actually yields; the
        # original returned len(self.labels), which desynchronized DataLoader's
        # reported length from the real iteration count.
        return self.num_batches * self.num_classes * self.samples_per_class
# Build the stratified sampler from ImageFolder's integer targets; DataLoader
# groups the yielded indices into batches of batch_size.
sampler = StratifiedSampler(dataset.targets, batch_size)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
# Conditional encoder: compresses an image plus its class label into the
# parameters (mu, logvar) of a diagonal Gaussian posterior over the latent space.
class Encoder(nn.Module):
    def __init__(self, channels, latent_dim, num_classes, image_size=64):
        """
        Args:
            channels: number of input image channels (1 = grayscale).
            latent_dim: latent vector size (also used for the label embedding).
            num_classes: number of condition classes.
            image_size: input spatial resolution; must be divisible by 16
                because the four stride-2 convolutions halve it four times.
                (Was a module-level global; parameterized with the same
                default so existing callers are unaffected.)
        """
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels, 32, 3, 2, 1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, 2, 1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, 2, 1)
        self.bn4 = nn.BatchNorm2d(256)
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        feat = image_size // 16  # spatial size after four stride-2 convs
        flat = 256 * feat * feat + latent_dim  # conv features + label embedding
        self.fc_mu = nn.Linear(flat, latent_dim)
        self.fc_logvar = nn.Linear(flat, latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x, labels):
        """Return (mu, logvar), each (B, latent_dim), for images `x` conditioned on `labels`."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = x.view(x.size(0), -1)
        # Condition on the class by concatenating its embedding to the features.
        x = torch.cat([x, self.label_emb(labels)], dim=1)
        return self.fc_mu(x), self.fc_logvar(x)
# Conditional decoder: maps a latent vector plus a class label back to an
# image in [-1, 1] (tanh output, matching the Normalize((0.5,), (0.5,)) range).
class Decoder(nn.Module):
    def __init__(self, channels, latent_dim, num_classes, image_size=64):
        """
        Args:
            channels: number of output image channels (1 = grayscale).
            latent_dim: latent vector size (also used for the label embedding).
            num_classes: number of condition classes.
            image_size: output spatial resolution; must be divisible by 16
                (four stride-2 transposed convs double the base feature map
                four times). Was a module-level global; parameterized with the
                same default so existing callers are unaffected.
        """
        super(Decoder, self).__init__()
        self.feat = image_size // 16  # base spatial size the fc output is reshaped to
        self.fc = nn.Linear(latent_dim + latent_dim, 256 * self.feat * self.feat)
        self.upconv1 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.bn1 = nn.BatchNorm2d(128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.bn2 = nn.BatchNorm2d(64)
        self.upconv3 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.bn3 = nn.BatchNorm2d(32)
        self.upconv4 = nn.ConvTranspose2d(32, channels, 2, 2)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.label_emb = nn.Embedding(num_classes, latent_dim)

    def forward(self, z, labels):
        """Decode latent `z` conditioned on `labels` into (B, channels, image_size, image_size)."""
        # Condition on the class by concatenating its embedding to the latent.
        z = torch.cat([z, self.label_emb(labels)], dim=1)
        x = self.fc(z)
        # Reshape using the stored per-instance size (the original read the
        # module-level global image_size here).
        x = x.view(x.size(0), 256, self.feat, self.feat)
        x = self.relu(self.bn1(self.upconv1(x)))
        x = self.relu(self.bn2(self.upconv2(x)))
        x = self.relu(self.bn3(self.upconv3(x)))
        return self.tanh(self.upconv4(x))
# Conditional VAE: ties the encoder and decoder together through the
# reparameterization trick so the whole model is trainable end-to-end.
class VAE(nn.Module):
    def __init__(self, channels, latent_dim, num_classes):
        super(VAE, self).__init__()
        self.encoder = Encoder(channels, latent_dim, num_classes)
        self.decoder = Decoder(channels, latent_dim, num_classes)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) as z = mu + sigma * eps, keeping gradients w.r.t. mu/logvar."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, labels):
        """Encode `x`, sample a latent, decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encoder(x, labels)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(latent, labels)
        return reconstruction, mu, logvar
# Instantiate the model and optimizer
num_classes = len(dataset.classes)
model = VAE(channels, latent_dim, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# VAE objective: summed pixel-wise reconstruction error plus the KL divergence
# of the approximate posterior N(mu, sigma^2) from the standard normal prior.
def vae_loss(recon_x, x, mu, logvar):
    """Return (reconstruction_loss, kld_loss, total_loss), each summed over the batch."""
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over all dimensions.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction, kld, reconstruction + kld
# Training loop
with tqdm(total=epochs, desc="Total Training Progress") as pbar_total:
    for epoch in range(epochs):
        model.train()
        epoch_recon_loss = 0.0
        epoch_kld_loss = 0.0
        epoch_total_loss = 0.0
        batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
        for batch_idx, (images, labels) in enumerate(batch_pbar):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            recon_images, mu, logvar = model(images, labels)
            recon_loss, kld_loss, total_loss = vae_loss(recon_images, images, mu, logvar)
            total_loss.backward()
            optimizer.step()
            # Accumulate per-batch sums (losses use reduction='sum') for epoch averages.
            epoch_recon_loss += recon_loss.item()
            epoch_kld_loss += kld_loss.item()
            epoch_total_loss += total_loss.item()
            batch_pbar.set_postfix({
                "Recon Loss": f"{recon_loss.item():.2f}",
                "KLD Loss": f"{kld_loss.item():.2f}",
                "Total Loss": f"{total_loss.item():.2f}"
            })
        batch_pbar.close()
        # NOTE(review): dividing by len(dataloader.dataset) assumes the sampler
        # covers roughly the whole dataset each epoch — verify against the
        # sampler's actual per-epoch index count.
        avg_recon = epoch_recon_loss / len(dataloader.dataset)
        avg_kld = epoch_kld_loss / len(dataloader.dataset)
        avg_total = epoch_total_loss / len(dataloader.dataset)
        pbar_total.set_postfix({
            "Avg Recon Loss": f"{avg_recon:.2f}",
            "Avg KLD Loss": f"{avg_kld:.2f}",
            "Avg Total Loss": f"{avg_total:.2f}"
        })
        pbar_total.update(1)
# Sample generation and saving (grayscale images are saved directly)
model.eval()
output_dir = r'I:\codes\vae(8.14)4'
num_samples_per_class = 700
classes = dataset.classes
num_classes = len(classes)
# Create one output folder per class
for class_name in classes:
    class_dir = os.path.join(output_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)
with torch.no_grad():
    for class_idx in tqdm(range(num_classes), desc="Generating samples for all classes"):
        class_name = classes[class_idx]
        class_dir = os.path.join(output_dir, class_name)
        generated_count = 0
        class_pbar = tqdm(total=num_samples_per_class, desc=f"Generating samples for {class_name}", leave=False)
        while generated_count < num_samples_per_class:
            # Generate in chunks of at most batch_size until the quota is met.
            batch_size_gen = min(num_samples_per_class - generated_count, batch_size)
            z = torch.randn(batch_size_gen, latent_dim).to(device)
            labels = torch.full((batch_size_gen,), class_idx, dtype=torch.long).to(device)
            samples = model.decoder(z, labels)  # (batch_size_gen, 1, 64, 64)
            # Undo Normalize((0.5,), (0.5,)): map tanh output [-1, 1] to [0, 1]
            samples = samples * 0.5 + 0.5
            for i in range(samples.size(0)):
                sample = samples[i]  # (1, 64, 64)
                save_path = os.path.join(class_dir, f'generated_{generated_count}.png')
                save_image(sample, save_path)  # saved as single-channel grayscale PNG
                generated_count += 1
                class_pbar.update(1)
        class_pbar.close()
print("灰度样本生成并保存完成。")
# TODO(author note): further improve the model so generated samples have higher
# quality and their details match the training images more closely.