GAN写程序记录

程序名:pytorch-GAN-timeseries

建议:据报道,在使用gan生成序列的若干工作中,循环鉴别器通常比卷积鉴别器更不稳定。因此,我推荐使用基于卷积的方法。
我没有对超参数和训练程序进行广泛的搜索,定性评估是唯一容易实现的。如果配置了一个目标任务(例如,学习一个策略),可以获得直观的和定量的评估,并用于选择最佳的模型。
在实际生成的性能和关于输入增量的错误之间有一点折衷。如果在delta上有轻微的精度对最终任务不是问题,它的误差可以忽略;如果想要尽可能地减少增量上的误差,可以增加监督目标的权重,或者使用监督微调。
这种训练有时容易出现模式崩溃:当前的实现可能受益于使用最近的GAN变体,如Wasserstein GAN。只要改变训练中对抗的部分就足够了。
在gan中注入条件的标准方法不能不考虑由增量产生的问题:正如我在一些独立实验中观察到的那样,基于因果卷积的神经网络能够很容易地解决检测给定序列的增量的问题。因此,接收增量作为输入的鉴别器可以很容易地区分具有正确增量的实序列和具有不正确增量的假序列。

WGAN相较于GAN的改进

作者把它命名为Critic以便与Discriminator作区分(注意:GAN之前的D网络都叫discriminator,但是由于这里不是做分类任务,WGAN作者觉得叫discriminator不太合适,于是将其叫为critic。)。两者的不同之处在于:
Critic最后一层抛弃了sigmoid,因为它输出的是一般意义上的分数,而不像Discriminator输出的是概率。
Critic的目标函数没有log项,这是从上面的推导得到的。
Critic在每次更新后都要把参数截断在某个范围,即weight clipping,这是为了保证上面讲到的Lipschitz限制。
Critic训练得越好,对Generator的提升更有利,因此可以放心地多训练Critic

WGAN的几大优越之处:
不再需要纠结如何平衡Generator和Discriminator的训练程度,大大提高了GAN训练的稳定性:Critic(Discriminator)训练得越好,对提升Generator就越有利。
即使网络结构设计得比较简陋,WGAN也能展现出良好的性能,包括避免了mode collapse的现象,体现了出色的鲁棒性。
Critic的loss很准确地反映了Generator生成样本的质量,因此可以作为展现GAN训练进度的定性指标。

  1. gan网络最终生成的数据与输入进去输入数据的batch_size 大小一样,最终将网络训练好,设置好fix_noise,通过调整batch_size的大小控制生成数量的多少
### 使用MATLAB编基于GAN进行缺失数据生成的程序 为了在MATLAB中实现基于GAN(生成对抗网络)的缺失数据生成,可以遵循以下结构化的方法。这种方法不仅利用了GAN的强大生成能力,还特别针对处理缺失数据进行了调整。 #### 1. 数据预处理 首先需要准备用于训练的数据集,并对其进行必要的预处理操作。这包括标准化输入特征以及标记哪些位置存在缺失值: ```matlab % 加载原始数据矩阵 X (NxD),其中 N 是样本数量, D 是特征维度 load('data.mat'); % 假设 data.mat 文件中含有名为 'X' 的变量 % 找到所有含有缺失值的位置 missingMask = isnan(X); % 将 NaN 替换为零或其他合适的填充值 X_filled = fillmissing(X,'constant',0); ``` #### 2. 构建生成器判别器模型 定义两个主要组件——生成器G判别器D。这里采用简单的全连接层架构作为示例;实际应用可根据具体需求选用更复杂的卷积/反卷积结构[^1]。 ```matlab generatorLayers = [ fullyConnectedLayer(128,'Name','fc1') reluLayer('Name','relu1') fullyConnectedLayer(size(X,2),'Name','fc2')]; discriminatorLayers = [ fullyConnectedLayer(128,'Name','fc1') leakyReluLayer(0.2,'Name','lrelu1') fullyConnectedLayer(1,'Name','fc2') sigmoidLayer()]; ``` #### 3. 定义损失函数与优化策略 设置适当的损失函数来衡量真假样本之间的差异,并指定Adam等常用梯度下降方法来进行参数更新。 ```matlab options = trainingOptions('adam',... 'MaxEpochs',50,... 'MiniBatchSize',64,... 'InitialLearnRate',0.0002,... 'GradientThreshold',10,... 'Shuffle','every-epoch',... 'Verbose',false,... 'Plots','training-progress'); ``` #### 4. 训练循环逻辑 构建自定义训练循环,在每次迭代过程中交替更新生成器判别器权重直至满足终止条件为止。 ```matlab for epoch=1:options.MaxEpochs % 获取当前批次的真实样本及其对应的二进制掩码 idx = randperm(numel(missingMask), options.MiniBatchSize); realSamples = X_filled(idx,:); masks = double(~missingMask(idx,:)); % 更新判别器... [lossD, gradD] = dlfeval(@computeDiscLoss, netD, ... cat(3,realSamples.*masks, genData .* masks), labelsTrue); % ...接着更新生成器 [~,gradG] = dlfeval(@computeGenLoss, netG, noiseInput, labelsFalse); end ``` 上述代码片段展示了如何通过`dlarray`对象传递动态图中的张量,并调用`dlfeval()`执行自动微分计算导数向量。注意这里的`netG`, `netD`分别代表初始化后的生成器判别器网络实例,而`genData`,`noiseInput`则表示由随机噪声驱动产生的假样本集合。 #### 5. 测试阶段:填补缺失部分 完成训练之后,就可以使用已经学得的知识去修复那些不完整的记录了! ```matlab function filledData = imputeMissingValues(netG, incompleteData) % 输入应为包含NaN项的时间序列数组 mask = ~isnan(incompleteData); z = randn([size(incompleteData)]); generatedPart = predict(netG,z).*double(mask); %#ok<PERM> filledData = incompleteData; filledData(isnan(filledData)) = generatedPart(isnan(filledData)); end ``` 此辅助功能接收任意形状的二维表单作为参数,返回经过插补处理的结果副本。内部实现了仅对未知区域施加变换的操作方式,确保已知数值保持不变。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值