在pytorch中,可以导入tensorboard模块,可视化网络结构及训练流程。
下面通过“CNN训练MNIST手写数字分类”的小例子来学习一些可视化工具的用法,只需要加少量代码。
一、tensorboardX的安装
pip install tensorboard
pip install tensorflow
pip install tensorboardX
二、导入tensorboardX
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#writer就相当于一个日志,保存你要做图的所有信息。第二句就是在你的项目目录下建立一个文件夹log,存放画图用的文件。刚开始的时候是空的
from tensorboardX import SummaryWriter
writer = SummaryWriter('log') #建立一个保存数据用的东西
三、搭建模型
#定义超参数
batch_size = 64
learning_rate = 1e-2
num_epoches = 20
#对数据进行预处理
data_tf = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])]
)
# 定义网络
class CNN(nn.Module):
def __init__(self):
super(CNN, se