第G4周:CGAN|生成手势图像 | 可控制生成


第G4周:CGAN|生成手势图像 | 可控制生成)

第G4周:CGAN|生成手势图像 | 可控制生成

一、前言

二、我的环境

  • 电脑系统:Windows 10
  • 语言环境:Python 3.8.5
  • 编译器:Spyder

三、代码实现

1、导入第三方库

# -*- coding:utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torchsummary import summary
import matplotlib.pyplot as plt
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

2、数据预处理

train_transform = transforms.Compose([
    transforms.Resize(int(128* 1.12)),   ## 图片放大1.12倍
    transforms.RandomCrop((128, 128)),     ## 随机裁剪成原来的大小
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])
batch_size = 64
train_dataset = datasets.ImageFolder(root='E:/BaiduNetdiskDownload/GAN-Data/rps', transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=6)

3、数据可视化

def show_images(dl):
    for images,_ in dl:
      fig,ax = plt.subplots(figsize=(10,10))
      ax.set_xticks([]);ax.set_yticks([])
      ax.imshow(make_grid(images.detach(),nrow=16).permute(1,2,0))
      break
 
show_images(train_loader)

在这里插入图片描述

四、定义模型

4.1 模型构建

latent_dim = 100
n_classes = 3
embedding_dim = 100

4.2 定义生成器

'''
定义生成器 Generator
'''
'''
定义生成器 Generator
'''
# 自定义权重初始化函数,用于初始化生成器和判别器的权重
def weights_init(m):
    # 获取当前层的类名
    classname = m.__class__.__name__
 
    # 如果当前层是卷积层(类名中包含 'Conv' )
    if classname.find('Conv') != -1:
        # 使用正态分布随机初始化权重,均值为0,标准差为0.02
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    
    # 如果当前层是批归一化层(类名中包含 'BatchNorm' )
    elif classname.find('BatchNorm') != -1:
        # 使用正态分布随机初始化权重,均值为1,标准差为0.02
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        # 将偏置项初始化为全零
        torch.nn.init.zeros_(m.bias)
 
 
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
 
        # 定义条件标签的生成器部分,用于将标签映射到嵌入空间中
        # n_classes:条件标签的总数
        # embedding_dim:嵌入空间的维度
        self
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值