廖星宇 多项式拟合示例（polynomial fitting example）
UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
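将 PyTorch 更新到 0.4.0 版后，运行为 0.3.1 版编写的代码会出现上面的警告，提醒用户在下一个版本中这将成为错误。
(After upgrading to PyTorch 0.4.0, code written for 0.3.1 emits the warning above; it will become an error in the next release.)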
修改方法：将 `loss.data[0]` 替换为 `loss.item()`（Fix: replace `loss.data[0]` with `loss.item()`）：
# print_loss = loss.data[0]
print_loss = loss.item()
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 14 16:19:59 2018
@author: lthpc
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
from torch import nn,optim
def make_features(x, degree=3):
    """Build polynomial features, i.e. a matrix with columns [x, x^2, ..., x^degree].

    Args:
        x: 1-D tensor of sample points, shape (N,).
        degree: highest power to include. The default of 3 matches the three
            coefficients in ``W_target``.

    Returns:
        Tensor of shape (N, degree) whose i-th column is ``x ** (i + 1)``.

    Raises:
        ValueError: if ``degree`` is less than 1 (there would be no columns).
    """
    if degree < 1:
        raise ValueError('degree must be >= 1, got %r' % (degree,))
    # Turn (N,) into a column (N, 1) so the powers concatenate side by side.
    x = x.unsqueeze(1)
    return torch.cat([x ** i for i in range(1, degree + 1)], 1)
# Ground-truth coefficients for the features [x, x^2, x^3], stored as a
# (3, 1) float32 column vector so it can be used directly with `mm`.
W_target = torch.tensor([[0.5], [3.0], [2.4]], dtype=torch.float32)
# Ground-truth intercept, a one-element float32 tensor.
b_target = torch.tensor([0.9], dtype=torch.float32)
def f(x):
    """Evaluate the target (approximated) polynomial on a feature matrix.

    Computes ``x.mm(W_target) + b_target[0]``. Since ``W_target`` is a
    (3, 1) column, an (N, 3) feature matrix yields an (N, 1) result.
    """
    linear_part = x.mm(W_target)
    intercept = b_target[0]
    return linear_part + intercept
def get_batch(batch_size=32):
    """Sample one random training batch, i.e. an (x, f(x)) pair.

    Draws ``batch_size`` standard-normal points, expands them with
    ``make_features`` and labels them with ``f``.

    Returns:
        ``(features, targets)`` wrapped in ``Variable``; both are moved to the
        GPU when CUDA is available.
    """
    points = torch.randn(batch_size)
    features = make_features(points)
    targets = f(features)
    on_gpu = torch.cuda.is_available()
    features, targets = Variable(features), Variable(targets)
    if on_gpu:
        features, targets = features.cuda(), targets.cuda()
    return features, targets
# Define model
class poly_model(nn.Module):
    """Linear regression over three polynomial features.

    A single ``nn.Linear(3, 1)`` layer. In this script it is fed the output
    of ``make_features``, so it learns the coefficients and intercept of a
    cubic polynomial.
    """

    def __init__(self):
        super(poly_model, self).__init__()
        # The attribute name `poly` determines the state_dict keys; keep it.
        self.poly = nn.Linear(in_features=3, out_features=1)

    def forward(self, x):
        return self.poly(x)
# Build the model (moved to the GPU when one is available), the MSE loss and
# a plain SGD optimizer; `epoch` counts training steps taken.
model = poly_model()
if torch.cuda.is_available():
    model = model.cuda()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
epoch = 0
# Train on fresh random batches until the batch MSE drops below 1e-3.
while True:
    # Get data
    batch_x, batch_y = get_batch()
    # Forward pass
    output = model(batch_x)
    loss = criterion(output, batch_y)
    # loss.item() replaces the pre-0.4 idiom loss.data[0].
    print_loss = loss.item()
    # NaN/inf never satisfies the `< 1e-3` stop test below, so without this
    # guard a diverging run would loop forever.
    if not np.isfinite(print_loss):
        raise RuntimeError('training diverged: loss=%r after %d steps'
                           % (print_loss, epoch))
    # Reset gradients
    optimizer.zero_grad()
    # Backward pass
    loss.backward()
    # Reuse the Python float from .item() rather than converting the tensor again.
    print('loss:%.5f\n' % print_loss)
    # Update parameters
    optimizer.step()
    epoch += 1
    if print_loss < 1e-3:
        break
# Evaluate the trained model on an evenly spaced grid over [-1, 1] and plot
# it against the true polynomial.
x = np.linspace(-1, 1, 30)
# Same features as training: built in float64 from NumPy, then cast to float32.
x_sample = make_features(torch.from_numpy(x)).float()
y_actual = f(x_sample)
# Put the inputs on whatever device the model's parameters live on, so this
# also runs on CPU-only machines.
device = next(model.parameters()).device
# Inference only: no autograd graph is needed.
with torch.no_grad():
    y_predict = model(x_sample.to(device))
plt.plot(x, y_actual.numpy(), 'ro', x, y_predict.cpu().numpy())
plt.legend(['real point', 'fit'])
plt.show()