网络结构可视化的重要性
模型结构可视化是一个非常重要的工具,特别当你在:
- 实现论文中的结构,对他并不熟悉
- 实现自己自定义的网络结构
通过检查输出的图相,你可以知道网络设计的逻辑上是否有瑕疵,包括:
- 网络中的层序不正确
- 卷积层或池化层后的输出维度不正确
所以我们建议在每个卷积块和池化层块之后将模型可视化,让自己能够进行验证。
结构可视化实现
-
首先我们需要安装一个库:
graphviz
mac具体的安装步骤可以参考我的另一篇博客:Mac上安装graphviz. 成功解决ImportError: Failed to import pydot
Ubuntu:
sudo apt-get install graphviz
-
接下来,我们还需要安装两个python包
pip install graphviz==0.5.2
pip install graphviz==1.0.0
- 创建
visualize_architecture.py
并插入如下代码:
from nn.conv.lenet import LeNet
from tensorflow.keras.utils import plot_model
# 初始化LeNet并将网络结构可视化图像写入硬盘
model = LeNet.build(28, 28, 1, 10)
plot_model(model, to_file="lenet.png", show_shapes=True)
运行结果: