可视化网络模型可以使用tensorboard和netron,这里netron对应复杂模型处理的比较好,tensorboard出不了图。
各阶段输入输出形状可以使用torchsummary(使用pip install torchsummary 安装)
一、 可视化网络模型
首先需要现成的网络结构或自定义的网络结构。这里使用现成的ResNet网络
import torch
import torch.nn as nn
import torch.nn.functional as F
class XX(nn.Module):
def __init__(self, in_channels=3,...):
super(XX, self).__init__()
pass
def forward(self, x):
return x
pass
安装好torchsummary之后,使用简单的几行代码(模型实例化,输入形状即可)
from torchsummary import summary
model = ResNet(depth=18)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
summary(model,(3,1300,800)) #(3,1300,800)输入形状
其中summary默认device是‘cuda’,可以改成以下代码,让模型和输入统一,(不会出现“RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same”的错误)
summary(model,(3,1300,800),device=device)
输入结果如下
二、 可视化网络模型
安装netron或tensorboard。 推荐使用netron
1. netron GitHub地址
通过以下代码进行可视化模型
import netron
# 实例化网络和输入形状
model = ResNET(depth=18)
input_sample = torch.randn(1, 3, 64, 64)
# 方式一 保存pth文件 并让netron读取
# model_path = 'cross.pth'
# torch.save(model.state_dict(),model_path)
# 方式二 导出onnx格式并读取
torch.onnx.export(model,input_sample,f='cross.onnx')
netron.start('cross.onnx')
用本地浏览器(推荐edge)打开代码运行提供的网址,即可看到网络模型
并且可以点击其中任意一模块查看输入输出等详细信息
2 tensoboard
model = ResNet(depth=18)
input_sample = torch.randn(1,3, 64, 64)
log_dir = 'runs/cnfe'
if os.path.exists(log_dir):
shutil.rmtree(log_dir)
# 创建一个TensorBoard记录器
writer = SummaryWriter(log_dir)
output = model(input_sample)
print("模型输出尺寸:", output[0].shape)
# 将模型的计算图添加到TensorBoard
writer.add_graph(model, input_sample)
print("计算图已成功写入 TensorBoard。")
# 关闭TensorBoard记录器
writer.close()
终端输入tensorboard --logdir=runs (路径一定要对,不然找不到)