GAN_conv 代码详解

本文详细介绍了PyTorch中的核心概念,包括张量操作如tensor.clamp,DataLoader的工作原理,Variable的属性及反向传播机制,以及GPU数据处理。还探讨了损失函数的反向传播(loss.backward())和优化器的参数更新(optimizer.step()),以及模型参数的保存与加载。此外,讲解了数据集处理类torch.utils.data.Dataset和DataLoader的作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.tensor.clamp(a,min,max)

#将tensor类型的a中的小于min的元素设为min,大于max的元素设为max

2. DataLoader(minst, batch_size, shuffle=True, num_workers)

#shuffle=True在每个epoch中打乱数据再取batch
#每次dataloader加载数据时:dataloader一次性创建num_worker个工作进程,
#worker也是普通的工作进程
#并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。
#然后,dataloader从RAM中找本轮迭代要用的batch,如果找到了,就使用。如果没找到,
#就要num_worker个worker继续加载batch到内存,直到dataloader在RAM中找到目标batch。
#一般情况下都是能找到的,因为batch_sampler指定batch时当然优先指定本轮要用的batch。

3.torch.autograd.Variable

autograd//自动微分

3.1 Variable中包含三个属性:data,grad,grad_fn

data:tensor形式的原始数据,通过.data来访问原始raw data (tensor)

grad:通过调用.backward()来计算梯度, , 并将变量梯度累加到.grad

grad_fn:指向Function对象,用于反向传播的梯度计算之用

 

 3.2 Variable的属性

每个Variable都有两个属性,requires_grad和volatile, 这两个属性都可以将子图从梯度计算中排除并可以增加运算效率

3.2.1 requires_grad:排除特定子图,不参与反向传播的计算,即不会累加记录grad

3.2.2 volatile: 推理模式, 计算图中只要有一个子图设置为True, 所有子图都会被设置不参与反向传

播计算,.backward()被禁止

 3.2.3 loss.backward()和optimizer.step()

loss.backward() //将Variable中的grad中计算出

optimizer.step() //将grad按照更新规则加到data上

4.CPU和GPU的数据转移

.cuda() //把数据转移到GPU上

.data() //读取Variable中的tensor

.cpu() //把数据转移到CPU上

5.Class  torch.utils.data.DataLoader(datasetbatch_size=1shuffle=Falsesampler=Nonebatch_sampler=Nonenum_workers=0)

将Dataset和sampler组合起来,提供一个迭代器遍历给定数据集

6. torch.utils.data.Dataset()

7.nn.LeakyReLU(negative_slope=0.01inplace=False)

#inplace设为True后,输出数据会覆盖原来的数据

8.torch.save(G.state_dict(),  'path')

8.1 torch.save()

保存模型参数,可以方便调用继续训练,保存的文件一般为后缀.pth

8.2 Module.state_dict() 和Module.parameter()

Module.state_dict() 和Module.parameter()中都存储了网络的参数

Module.state_dict().keys()返回字典的键,键为网络所有参数的名称

Module.state_dict()返回了一个包含网络参数的列表

Module.parameter() 其类型为class 'generator'

当调用Module.parameter()后返回了一个生成器(generator),可以发现Module.parameter()和Module.state_dict()存储了一样的数据,不过有区别的是,Module.parameter()中存放的Tensor的requires_grad设置为了True。

 

 

### CGAN MNIST 代码实现与解析 #### 条件生成对抗网络简介 条件生成对抗网络(CGAN)是在标准GAN的基础上改进而来的模型。通过向生成器和判别器中加入额外的条件信息,使得生成过程更加可控[^1]。 #### 数据准备 为了训练CGAN来生成MNIST手写数字图像,首先需要加载并预处理MNIST数据集: ```python from keras.datasets import mnist import numpy as np def load_data(): (X_train, y_train), (_, _) = mnist.load_data() # 归一化到[-1, 1] X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # 添加通道维度 X_train = np.expand_dims(X_train, axis=3) return X_train, y_train ``` 此函数会返回标准化后的图像数组以及对应的类别标签[^3]。 #### 构建生成器 生成器接收噪声向量z和目标类别的one-hot编码作为输入,并输出伪造的手写数字图像: ```python from keras.layers import Input, Dense, Reshape, Flatten, Embedding, multiply, Dropout from keras.models import Model def build_generator(latent_dim): noise_shape = (latent_dim,) label_shape = (10,) # 对于MNIST而言共有十个类别 noise = Input(shape=noise_shape) label = Input(shape=label_shape) model = Sequential([ Dense(256 * 7 * 7, activation="relu", input_dim=int(np.prod(noise_shape))), Reshape((7, 7, 256)), UpSampling2D(), Conv2DTranspose(filters=128,kernel_size=(5,5),strides=(1,1),padding='same',activation='relu'), BatchNormalization(momentum=0.8), UpSampling2D(), Conv2DTranspose(filters=64,kernel_size=(5,5),strides=(1,1),padding='same',activation='relu'), BatchNormalization(momentum=0.8), Conv2DTranspose(filters=1,kernel_size=(5,5),strides=(1,1),padding='same',activation='tanh') ]) label_embedding = Flatten()(Embedding(input_dim=10,output_dim=10)(label)) model_input = Multiply()([noise,label_embedding]) img = model(model_input) generator = Model(inputs=[noise,label], outputs=img,name="generator") return generator ``` 这段代码定义了一个基于卷积层的生成器架构,它能够根据给定的随机噪音和指定的类别生成相应的手写数字图像。 #### 定义判别器 判别器负责区分真实的MNIST样本与由生成器产生的假样本: ```python from keras.layers.convolutional import Conv2D, MaxPooling2D def build_discriminator(img_shape): img = Input(shape=img_shape) label = Input(shape=(10,), dtype='float32') model = Sequential([ Conv2D(16,(3,3), strides=(2,2), padding='same',input_shape=img_shape), LeakyReLU(alpha=0.2), Dropout(0.25), Conv2D(32,(3,3), strides=(2,2), padding='same'), ZeroPadding2D(padding=((0,1),(0,1))), LeakyReLU(alpha=0.2), Dropout(0.25), Flatten(), Dense(units=1, activation='sigmoid')]) d_in = Concatenate(axis=-1)([Flatten()(img), label]) validity = model(d_in) discriminator = Model(inputs=[img,label],outputs=validity,name="discriminator") return discriminator ``` 这里构建了一个简单的二分类器用于判断传入图像是真是伪的同时考虑到了其所属类别。 #### 训练流程概述 整个训练过程中交替更新两个子网参数直至收敛;每次迭代时先固定住生成器权重优化判别器性能,再冻结判别器调整生成器以欺骗前者[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值