tf.keras绘制网络结构
tf.keras.utils.plot_model()绘制深度学习结构图
- keras官方加载常见深度学习模型链接:https://keras.io/zh/applications/
- keras官方绘制模型结构图链接:https://keras.io/zh/visualization/
import tensorflow as tf
import pydot
import os
from tensorflow.keras.applications.inception_v3 import InceptionV3
model = InceptionV3(input_shape = (300,300, 3),
include_top = False,
weights = None)
os.environ["PATH"] += os.pathsep + r'C:\Program Files\Graphviz\bin'
model_name = 'InceptionV3'
def plot_model(model,model_name):
tf.keras.utils.plot_model(
model, # 实例化的模型
to_file=f'./{model_name}.png', # 保存到的路径
show_shapes=True, # 是否显示shape变化
show_layer_names=True, # 是否显示名称
rankdir='TB',
expand_nested=True,
dpi = 1000
)
if not os.path.exists(f'./{model_name}.png'):
plot_model(model, model_name)
print(f'绘制{model_name}模型成功!!')