GAN实战——生成手写字体

该代码示例展示了如何用PyTorch构建生成对抗网络(GAN)来训练MNIST数据集。生成器和判别器的网络结构被定义,随后进行训练,以生成类似手写数字的图像。训练过程包括损失计算、优化器更新以及损失跟踪。
部署运行你感兴趣的模型镜像
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time

time_start = time.time()

# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),  # 会做0-1归一化,也会channels, height, width
    transforms.Normalize((0.5,), (0.5,))
])

train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform)
dataLoader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)


# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        img = self.main(x)
        img = img.view(-1, 28, 28, 1)
        return img

# 判别器网络定义
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x


device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)

# 损失函数
loss_fn = torch.nn.BCELoss()

# 绘图函数
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i] + 1)/2)
        plt.axis('off')
    plt.show()


test_input = torch.randn(16, 100, device=device)


# GAN训练
D_loss = []
G_loss = []


# 训练循环
for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataLoader)  # 返回批次数
    for step, (img, _) in enumerate(dataLoader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)

        # 判别器的损失与优化
        d_optimizer.zero_grad()
        real_output = dis(img)  # 对判别器输入真实图片, real_output是对真实图片的判断结果
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 判别器在真实图像上的损失
        d_real_loss.backward()

        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())  # 判别器输入生成的图片,fake_output对生成图片的预测
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 判别器在生成图像上的损失
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.step()

        # 生成器的损失与优化
        g_optimizer.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 生成器的损失
        g_loss.backward()
        g_optimizer.step()

        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print("Epoch:", epoch)
        gen_img_plot(gen, test_input)

time_end = time.time()
print("花费总时间为:", time_end - time_start)

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

### 个性化手写字体生成中的GAN模型设计 生成对抗网络GAN)在个性化手写字体生成中展现出强大潜力,其核心思想是通过生成器与判别器的对抗训练,生成与目标风格高度相似的手写字体图像。在实际应用中,为了实现个性化风格控制,通常采用类别条件生成对抗网络(Conditional GAN, cGAN),该模型通过在生成器和判别器中引入类别标签作为输入,实现对生成字体风格的精确控制[^2]。 #### 模型结构与训练流程 在cGAN框架中,生成器的输入由随机噪声向量 $ z $ 和类别标签 $ y $ 组成,生成器的目标是根据 $ z $ 和 $ y $ 生成符合指定风格的手写字体图像。判别器则接收图像和类别标签作为输入,判断图像是否来自真实数据集。其损失函数定义为: $$ \min_{G} \max_{D} L_{cGAN}(G, D) = E_{x \sim P_{data}(x)}[\log D(x | y)] + E_{x \sim P(z)}[\log(1 - D(G(z | y)))] $$ 这种结构使得模型能够在训练过程中学习不同书写风格的特征,并在推理阶段根据输入的类别标签生成相应风格的字体[^4]。 #### 数据准备与预处理 生成个性化手写字体的前提是获取足够多的手写样本。可以通过扫描真实手写文档或使用公开数据集(如IAM Handwriting Database)构建训练数据集。每张图像需进行预处理,包括图像增强、二值化、归一化等操作,以提高模型训练的稳定性与生成质量[^1]。 #### 代码实现示例 以下是一个基于PyTorch实现的简单cGAN生成手写字体的代码框架: ```python import torch import torch.nn as nn # 定义生成器 class Generator(nn.Module): def __init__(self, latent_dim, num_classes, img_shape): super(Generator, self).__init__() self.label_emb = nn.Embedding(num_classes, num_classes) self.init_size = img_shape[1] // 4 self.l1 = nn.Sequential(nn.Linear(latent_dim + num_classes, 128 * self.init_size ** 2)) self.conv_blocks = nn.Sequential( nn.BatchNorm2d(128), nn.Upsample(scale_factor=2), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, img_shape[0], 3, stride=1, padding=1), nn.Tanh() ) def forward(self, z, labels): gen_input = torch.cat((self.label_emb(labels), z), -1) out = self.l1(gen_input) out = out.view(out.shape[0], 128, self.init_size, self.init_size) img = self.conv_blocks(out) return img # 定义判别器 class Discriminator(nn.Module): def __init__(self, img_shape, num_classes): super(Discriminator, self).__init__() self.label_embedding = nn.Embedding(num_classes, np.prod(img_shape)) self.model = nn.Sequential( nn.Conv2d(img_shape[0] + 1, 16, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25), nn.Conv2d(16, 32, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25), nn.Conv2d(32, 64, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25), nn.Conv2d(64, 128, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25), nn.Flatten(), nn.Linear(128 * (img_shape[1] // 16) ** 2, 1) ) def forward(self, img, labels): label_embedding = self.label_embedding(labels).view(img.shape[0], 1, img.shape[2], img.shape[3]) d_in = torch.cat((img, label_embedding), 1) validity = self.model(d_in) return validity ``` #### 模型优化与部署 训练过程中,应采用适当的优化策略,如使用Adam优化器、动态调整学习率、引入梯度惩罚(Wasserstein GAN with Gradient Penalty, WGAN-GP)等技术,以提升模型的收敛速度与生成质量。训练完成后,生成器可用于生成指定风格的手写字体图像,将其保存为TrueType(.ttf)或OpenType(.otf)字体文件,并嵌入到Word文档中进行字体换[^1]。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值