Github项目推荐|可视化 GAN 的训练过程

部署运行你感兴趣的模型镜像

点击上方“算法猿的成长“,关注公众号,选择加“星标“或“置顶”

总第 137 篇文章,本文大约 1000 字,阅读大约需要 5 分钟

今天介绍的一个开源的 github 项目,主要是实现了对 GAN 训练过程的可视化代码,项目链接如下:

https://github.com/EvgenyKashin/gan-vis

或者点击文章底部“阅读原文”,直接跳转。

接下来是简单介绍这份代码的情况,基本实现的功能,效果等。

前言

这是一个简单实现了学习和可视化 2d 的 GANs 的实验代码。在训练了数十个小时的 StyleGAN 后,现在可以通过快速的迭代(30s 左右)来直观的可视化一些超参数的情况(但是并不确定这种直观情况是否可以适用于更大的 GAN 模型)。主要是受到 https://poloclub.github.io/ganlab/ 的启发,但可能有人更希望在 Colab 中运行代码。

可视化

对训练的动态过程的可视化包括了:

  • 真实数据的分布情况(黑色的点)

  • 输入固定的噪音,然后由 G 网络生成的假数据;

  • D 网络对整个输入空间的决策边界,以及用不同颜色代表其输出的概率(红色代表判定为真实数据的可能性很高,蓝色则是低)

  • 绿色箭头表示每个生成的数据点,最大化 D 网络输出的方向

可视化结果

接下来是展示可视化的一些效果:

没有采用 batch-norm 的 G 和 D 的训练情况
加入 batch-norm 的 G 和 D 的训练情况
评价标准的可视化

第一行是训练的过程(输入是固定的噪音)以及多种评判标准(G 和 D 的梯度归一化,losses 以及 D 对真假数据的输出)。第二行展示了输入噪音以及 G 网络中间层的激活函数(映射为 2 维)

对输入噪音的 G 网络的转换

可调试的选项

  • 输入数据的分布

  • batch 大小,训练的 epochs

  • D 和 G 的学习率(可能是最重要的)

  • D 和 G 的优化器

  • 输入噪音的分布

  • 神经元的数量,激活函数

  • 损失函数(BCE,L2)

  • 权重初始化

  • 正则化(batch-norm,dropout,权重衰减)

采用的是 CPU,因为对可视化的实验已经满足速度的要求。

未来的工作

  • 增加更多的损失函数

  • 增加更多的正则化技术

项目代码可以直接访问 github 查看,或者关注我的公众号--【算法猿的成长】,在后台回复“play_gans",获取代码。


精选AI文章

1.  2020年计算机视觉学习指南

2. 是选择Keras还是PyTorch开始你的深度学习之旅呢?

3. 编写高效的PyTorch代码技巧(上)

4. 编写高效的PyTorch代码技巧(下)

5. 深度学习算法简要综述(上)

6. 深度学习算法简要综述(下)

7. 10个实用的机器学习建议

8. 实战|手把手教你训练一个基于Keras的多标签图像分类器

精选python文章

1.  python数据模型

2. python版代码整洁之道

3. 快速入门 Jupyter notebook

4. Jupyter 进阶教程

5. 10个高效的pandas技巧

精选教程资源文章

1. [资源分享] TensorFlow 官方中文版教程来了

2. [资源]推荐一些Python书籍和教程,入门和进阶的都有!

3. [Github项目推荐] 推荐三个助你更好利用Github的工具

4. Github上的各大高校资料以及国外公开课视频

5. GitHub上有哪些比较好的计算机视觉/机器视觉的项目?

欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!

 

如果觉得不错,在看、转发就是对小编的一个支持!

您可能感兴趣的与本文相关的镜像

GPT-SoVITS

GPT-SoVITS

AI应用

GPT-SoVITS 是一个开源的文本到语音(TTS)和语音转换模型,它结合了 GPT 的生成能力和 SoVITS 的语音转换技术。该项目以其强大的声音克隆能力而闻名,仅需少量语音样本(如5秒)即可实现高质量的即时语音合成,也可通过更长的音频(如1分钟)进行微调以获得更逼真的效果

### CycleGAN 可视化方法 CycleGAN 是一种用于无监督图像到图像转换的强大工具,其可视化过程可以帮助理解模型的行为以及评估训练效果。以下是几种常见的可视化方法及其对应的代码示例。 #### 1. **生成样本的对比** 通过比较原始图像与其经过 CycleGAN 转换后的结果,可以直观地观察模型的效果。通常会展示 `real_A` 到 `fake_B` 的映射,以及 `real_B` 到 `fake_A` 的逆向映射。 ```python import torch from torchvision.utils import make_grid import matplotlib.pyplot as plt def visualize_images(real_A, fake_B, real_B, fake_A): """显示真实图像与生成图像""" fig, axes = plt.subplots(2, 2, figsize=(8, 8)) # 将张量转为numpy数组并调整范围至[0, 1] def tensor_to_image(tensor): image = (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1) return image.permute(1, 2, 0).numpy() images = [ ("Real A", real_A), ("Fake B", fake_B), ("Real B", real_B), ("Fake A", fake_A) ] for ax, (title, img_tensor) in zip(axes.flatten(), images): ax.imshow(tensor_to_image(img_tensor.squeeze())) ax.set_title(title) ax.axis('off') plt.tight_layout() plt.show() # 假设已经加载了模型和数据 visualize_images(real_A, fake_B, real_B, fake_A) ``` 此部分展示了如何将生成的结果与原图进行对比[^1]。 --- #### 2. **损失曲线绘制** 为了监控训练过程中不同类型的损失变化情况(如身份损失、循环一致性损失和对抗损失),可以通过记录每轮迭代中的损失值来绘制曲线。 ```python import numpy as np import matplotlib.pyplot as plt def plot_loss_curves(losses_dict, epochs): """绘制各种损失随epoch的变化趋势""" plt.figure(figsize=(10, 6)) for loss_name, values in losses_dict.items(): plt.plot(np.arange(len(values)), values, label=loss_name) plt.xlabel("Epochs") plt.ylabel("Loss Value") plt.title("Training Loss Curves Over Epochs") plt.legend(loc="upper right") plt.grid(True) plt.show() # 示例字典结构 { 'D_A': [...], 'D_B': [...], 'G_ABA': [...] } plot_loss_curves({ "Generator G": gen_losses, "Discriminator D_A": disc_a_losses, "Discriminator D_B": disc_b_losses }, num_epochs) ``` 上述代码片段能够帮助分析模型收敛性和稳定性[^3]。 --- #### 3. **中间特征层激活热力图** 对于更深入的理解,还可以提取生成器内部某些卷积层的输出,并将其可视化成热力图形式。这有助于研究特定区域是如何被处理的。 ```python from torchvision.models.feature_extraction import create_feature_extractor def get_activation_heatmap(model, input_img, layer_name='model.7'): feature_extractor = create_feature_extractor( model.G_AB.eval(), return_nodes=[layer_name] ) activations = feature_extractor(input_img)[layer_name] heatmap = activations.mean(dim=1).squeeze().detach().cpu().numpy() normalized_heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) return normalized_heatmap def show_heatmap_on_image(image, heatmap): from PIL import Image import cv2 img_np = ((image.cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5) * 255).astype(np.uint8) resized_heatmap = cv2.resize(heatmap, dsize=img_np.shape[:2][::-1]) colored_map = cv2.applyColorMap((resized_heatmap*255).astype(np.uint8), colormap=cv2.COLORMAP_JET) overlayed = cv2.addWeighted(colored_map[:, :, ::-1], 0.4, img_np, 0.6, 0.) plt.imshow(Image.fromarray(overlayed.astype(np.uint8))) plt.axis('off'); plt.show() heatmp = get_activation_heatmap(Gs['G_AB'], real_A.unsqueeze(0)) show_heatmap_on_image(real_A.squeeze(), heatmp) ``` 这段脚本允许查看指定层次上的响应模式[^2]。 --- #### 4. **远程服务器环境下的 Jupyter Notebook 使用** 如果是在云端 GPU 上运行实验,则推荐利用 Jupyter Lab 来完成交互式的开发调试工作流。例如,在矩池云上部署好 PyTorch 环境之后: ```bash cd mnt/ git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git pip install visdom tqdm dominate scikit-image pillow==4.1.1 future imageio requests html webcolors termcolor pyyaml opencv-python jupyter lab & ``` 随后可通过浏览器访问 notebook 并执行前述任意一段绘图逻辑[^4]。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

spearhead_cai

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值