1. PyTorch 网络可视化——torchsummary
参考：参考1（原文链接缺失）
示例：
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# Small plain-convolution image feature extractor (three strided convs,
# a max-pool and one linear layer). Note: this is NOT a ResNet.
class ImageNet(nn.Module):
    """Map a 128-channel, 128x128 input to a 256-dim embedding.

    Layer output shapes for an input of shape (N, 128, 128, 128); all convs
    are unpadded:

      conv1 (k=3, s=2):   (N, 64, 63, 63)
      conv2 (k=3, s=2):   (N, 32, 31, 31)
      conv3 (k=3, s=2):   (N, 16, 15, 15)
      maxpool (2, s=2):   (N, 16, 7, 7)
      fc1:                (N, 256)

    ``fc1`` is sized for exactly 16*7*7 flattened features per sample, so only
    spatial sizes that reduce to 7x7 after the pool (e.g. 128x128) are valid.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(128, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32, 16, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.fc1 = nn.Linear(16 * 7 * 7, 256)

    def forward(self, x):
        """Return the (N, 256) embedding for a batch ``x`` of shape (N, 128, H, W).

        Raises:
            RuntimeError: if H, W do not reduce to 7x7, because the flattened
                width then no longer matches ``fc1``'s input size.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # The pooled values are maxima of ReLU outputs and hence already
        # non-negative, so another ReLU after the pool would be a no-op.
        x = self.maxpool(x)
        # Flatten per sample while keeping the batch dimension. A
        # view(-1, 16*7*7) would instead silently regroup samples (changing
        # the batch size) for mis-sized inputs; here fc1 fails loudly.
        x = torch.flatten(x, 1)
        return self.fc1(x)
from torchsummary import summary

net = ImageNet()
# Print each layer's output shape and parameter count.
# ImageNet.forward takes a single tensor, so pass exactly one input size as
# (C, H, W) without the batch dimension. Passing a list of two sizes would make
# summary() call net(x1, x2), which raises a TypeError.
summary(net, input_size=(128, 128, 128), batch_size=1, device="cpu")