1.保存和加载模型
1.保存参数
torch.save(model_object.state_dict(), ‘params.pkl’)
model_object.load_state_dict(torch.load(‘params.pkl’))
model_object:为模型的实例化
加载时,要定义该类和实例
2.保存模型
torch.save(model_object, ‘model.pkl’)
model = torch.load(‘model.pkl’)
model_object:为模型的实例化
加载时,model即为可用模型
2.torch.max(test_output, 1)
torch.max(test_output, 1)
输出格式[tensor([最大值]),tensor([最大值的位置])]
参数1/0,输出行/列最大
3.直接调用数据集送入模型
mages=Variable(test_dataset.test_data[:100].reshape(-1, 28*28).float())
.test_data[:100]:test_dataset中标签为test_data的数据
reshape:改变数据类型
.float():改变数据byte为float型
4.torch.max(test_output, 1)[1].data.numpy().squeeze()
torch.max(test_output, 1)[1].data.numpy().squeeze()
data.numpy():数据从variable转为tensor(data) 再转为numpy(numpy)
5.view()
改变输入信号的维度,image.view(-1,28*28),-1:根据原始尺寸和列,固定行
view和reshape的区别:reshape可以处理非连续区域的张量
6.Pytorch出现 raise NotImplementedError
经过反复检查,发现是 forward函数 出了问题,没检测到forward函数
查看修改forward函数