nn.Conv2d的使用
此处我们仍然使用官网自带的数据集进行训练,最后将其可视化
加载数据集和可视化部分在此处不在介绍,若需要了解:
加载数据集:torch.utils.data中的DataLoader数据加载器(附代码)_硕大的蛋的博客-优快云博客
tensorboard可视化工具:Tensorboard 可视化工具的使用-史上最简单(附代码)_硕大的蛋的博客-优快云博客
导入相应的包和模块
import torch import torchvision import torch.nn as nn from torch.utils.data import DataLoader from tensorboardX import SummaryWriter
获取数据
dataset = torchvision.datasets.CIFAR10('../BigData', train=False, transform=torchvision.transforms.ToTensor(), download=True) dataloader = DataLoader(dataset, batch_size=64)
创建神经网络
class Gsw(nn.Module): def __init__(self): super(Gsw, self).__init__() self.con1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0) def forward(self, input): output = self.con1(input) return output
训练并将其可视化
gsw = Gsw() writer = SummaryWriter('LOGS/011log') for step, data in enumerate(dataloader): imgs, target = data output = gsw(imgs) writer.add_images('input', imgs, step) output = torch.reshape(output, (-1, 3, 30, 30)) writer.add_images('output', output, step)
完整代码
# 开发时间: 2021/11/21 17:51 import torch import torchvision import torch.nn as nn from torch.utils.data import DataLoader from tensorboardX import SummaryWriter dataset = torchvision.datasets.CIFAR10('../BigData', train=False, transform=torchvision.transforms.ToTensor(), download=True) dataloader = DataLoader(dataset, batch_size=64) class Gsw(nn.Module): def __init__(self): super(Gsw, self).__init__() self.con1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0) def forward(self, input): output = self.con1(input) return output gsw = Gsw() writer = SummaryWriter('LOGS/011log') for step, data in enumerate(dataloader): imgs, target = data output = gsw(imgs) writer.add_images('input', imgs, step) output = torch.reshape(output, (-1, 3, 30, 30)) writer.add_images('output', output, step)