import torch #导入pytorch库
import matplotlib.pyplot as plt # 导入matplotlib(用于画图)库中pyplot模块,命名为plt
import random #导入random ,用于生成随机数
def create_data(w, b, data_num): #生成数据 w权重向量,b偏置项,data_num需要生成的数据点数量
x = torch.normal(0, 1, (data_num, len(w))) #生成符合正态分布的随机数 0正态分布的均值,1正态分布的标准差,(data_num, len(w))输出张量的形状(维度和大小)
y = torch.matmul(x, w) + b #matmul表示矩阵相乘
noise = torch.normal(0, 0.01, y.shape) #噪声要加到y上
y += noise
return x, y
num=500 #赋值
true_w = torch.tensor([8.1,2,2,4]) #创建张量
true_b = torch.tensor(1.1)
X, Y= create_data(true_w,true_b,num) #生成数据
plt.scatter(X[:,3],Y,1) #绘制散点图 1指定散点图中点的大小(默认20)
plt.show() #用于在屏幕上显示创建的图表
def data_provider(data,labal,batchsize): #定义函数,为机器学习模型提供数据 data数据集 labal对应数据集标签 batchsize每个批次大小
length=len(labal) #获取标签的长度
indices=list(range(length)) #创建索引列表
#我不能按顺序取 把数据打乱
random.shuffle(indices) #打乱列表顺序
for each in range(0,length,batchsize): #每次访问这个函数,就能提供一批数据
get_indices=indices[each:each+batchsize] #从打乱后的索引列表中提取当前批次的索引
get_data=data[get_indices] #获取该批次数据
get_label=label[get_indices] #获取该批次标签
yield get_data,get_label #返回当前批次的数据和标签 有存档点的return
batchsize=16 #每次处理的数据样本数量
# for batch_x,batch_y in data_provider(X,Y,batchsize): #遍历每个批次的数据和标签
# print(batch_x,batch_y) #打印
# break
def fun(x,w,b): #线性变换函数 x数据 w权重 b偏置
pred_y=torch.matmul(x,w)+b #矩阵相乘
return pred_y
def maeLoss(pre_y,y): #计算预测值和真实值之间的平均绝对误差
return torch.sum(abs(pre_y-y))/len(y)
def sgd(paras,lr): #随机梯度下降,更新参数
with torch.no_grad(): #上下文管理器,用于禁用梯度计算
for para in paras: #遍历参数列表
para -= para.grad*lr
para.grad.zero_() #清空参数梯度
lr=0.03 #设置学习率
w_0=torch.normal(0,0.01,true_w.shape,requires_grad=True) #这个w需要计算梯度
b_0=torch.tensor(0.01,requires_grad=True) #创建张量
print(w_0,b_0)
epochs=50 #设置训练的轮数
for epoch in range(epochs): #执行多次训练迭代
data_loss=0 #初始化变量,累积损失值
for batch_x,batch_y in data_provider(X,Y,batchsize):
pred_y=fun(batch_x,w_0,b_0)
loss=maeLoss(pred_y,batch_y)
loss.backward() #反向传播
sgd([w_0, b_0], lr)
data_loss += loss
print("epoch %03d:loss:%.6f"%(epoch,data_loss))
print("真实的函数值是",true_w,true_b)
print("训练得到的参数值是",w_0,b_0)
idx=3
plt.plot(X[:,idx].detach().numpy(),X[:,idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy()) #绘制二维图
plt.scatter(X[:,idx],Y,1)
plt.show()