本文主要是用PyTorch来实现一个简单的回归任务。
编辑器:spyder
1.引入相应的包及生成伪数据
import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable
# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
# 变为Variable
x, y = Variable(x), Variable(y)
其中torch.linspace
是为了生成连续间断的数据,第一个参数表示起点,第二