1. Install Graphviz
brew install graphviz
I ran into this error during the installation:
Directory not empty @ dir_s_rmdir - /usr/local/opt/openssl
The fix was to remove and reinstall openssl:
sudo rm -rf /usr/local/opt/openssl
brew install openssl
2. Install the Python packages (the code in step 3 also needs torchviz):
pip install graphviz torchviz
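The graphviz Python package only produces DOT source. Rendering, including `vis_graph.view()` in step 3, calls the system `dot` executable installed by Homebrew in step 1. The small standard-library check below is one way to confirm that `dot` is actually on PATH before going further. The function name `dot_version` is an illustrative choice of mine and is not part of either package.

```python
import shutil
import subprocess


def dot_version():
    """Return Graphviz's version banner, or None if `dot` is not on PATH."""
    dot_path = shutil.which("dot")
    if dot_path is None:
        return None
    # `dot -V` writes its version banner to stderr rather than stdout.
    result = subprocess.run([dot_path, "-V"], capture_output=True, text=True)
    return result.stderr.strip() or result.stdout.strip()


if __name__ == "__main__":
    version = dot_version()
    if version is None:
        print("dot not found: install Graphviz first (e.g. brew install graphviz)")
    else:
        print(version)
```

If this reports that `dot` is missing, `pip install graphviz` alone is not enough, because pip does not install the Graphviz binaries.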
3. Visualize an MLP with graphviz (through torchviz's make_dot)
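import torch
from torch import nn
from torchviz import make_dot

# A two-layer MLP: 8 inputs -> 16 hidden units (tanh) -> 1 output
model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

# Variable has been deprecated since PyTorch 0.4; plain tensors work directly
x = torch.randn(1, 8)
# Passing the named parameters lets make_dot label the weight and bias nodes
vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
# Render the graph (PDF by default) and open it in the system viewer
vis_graph.view()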
Result:
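make_dot draws the autograd graph, so its nodes are backward operations (in recent PyTorch versions, names such as AddmmBackward0 and TanhBackward0) plus the parameter tensors. It does not draw one box per layer. If you only want a layer-level diagram, you can write the DOT source by hand. Below is a minimal, dependency-free sketch for the same three modules (W0, tanh, W1). The helper name `mlp_to_dot` and the node styling are my own assumptions, and the output does not reproduce what torchviz emits.

```python
def mlp_to_dot(layers):
    """Build Graphviz DOT source for a simple chain of layers.

    `layers` is a list of (name, description) pairs, e.g. ("W0", "Linear(8, 16)").
    """
    lines = [
        "digraph mlp {",
        "    rankdir=LR;",  # lay the chain out left to right
        '    input [label="x: (1, 8)", shape=oval];',
    ]
    previous = "input"
    for name, description in layers:
        # "\\n" puts a literal \n in the DOT label, which Graphviz renders as a line break
        lines.append(f'    {name} [label="{name}\\n{description}", shape=box];')
        lines.append(f"    {previous} -> {name};")
        previous = name
    lines.append("}")
    return "\n".join(lines)


# Same module names and sizes as the nn.Sequential model above
mlp_layers = [
    ("W0", "Linear(8, 16)"),
    ("tanh", "Tanh()"),
    ("W1", "Linear(16, 1)"),
]

if __name__ == "__main__":
    print(mlp_to_dot(mlp_layers))
```

Save the printed text to a file such as mlp.dot and render it with `dot -Tpng mlp.dot -o mlp.png`. To see the DOT source that make_dot generated instead, print `vis_graph.source`.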