Class1线性回归简洁实现代码

Class1线性回归简洁实现代码

# 导入numpy库,用于数值计算
import numpy as np
# 导入PyTorch库
import torch
# 从torch.utils模块中导入data子模块
from torch.utils import data
# 导入d2l
from d2l import torch as d2l
# 设置真实权重,大小为[2,-3.4]
true_w = torch.tensor([2,-3.4])
# 设置真实偏置,大小为4.2A
true_b = 4.2
# 定义输入特征和标签
features,labels = d2l.synthetic_data(true_w,true_b,1000)
# 构建PyTorch小批量数据加载器
# data_arrays:元组包含两个张量(features,labels)
# batch_size:批量大小
# is_train:是否打乱数据
def load_array(data_arrays,batch_size,is_train=True):
    """构造一个PyTorch数据迭代器"""
    # TensorDataset:PyTorch提供的简单数据封装类
    # 将features和labels组合成一个数据集,便于后续批量访问
    dataset = data.TensorDataset(*data_arrays)
    # DataLoader:数据加载器
    # dataset:数据源
    # batch_size:每次返回的样本数量
    # shuffle=is_train:训练模式,打乱数据
    return data.DataLoader(dataset,batch_size,shuffle=is_train)

# 设置批量大小
batch_size = 10
# 构造数据迭代器
data_iter = load_array((features,labels),batch_size)
# 获取一个迭代器
# next():取出第一个batch数据
next(iter(data_iter))
# 引入神经网络模块
from torch import nn
# 创建模型,包含一个线性层nn.Linear(2,1)
# Sequential:可以添加更多层
net = nn.Sequential(nn.Linear(2,1))
# 设置模型初始权重和偏置
# net[0]:获取nn.Linear(2,1)
# net[0].weight.data:线性层的权重张量
# normal_(0,0.01):用均值为0,标准差为0.01的正态分布进行初始化
net[0].weight.data.normal_(0,0.01)
# 偏置项初始化为0
net[0].bias.data.fill_(0)
# 使用L2范数计算均方误差
loss = nn.MSELoss()
# 实例化SGD
# net.paramters():模型可学习参数的迭代器,包括权重和偏置
trainer = torch.optim.SGD(net.parameters(),lr=0.03)
# 设置训练轮数
num_epochs = 3
# 定义3轮训练
for epoch in range(num_epochs):
    # 设置批次训练
    # X:[10,2],特征
    # y:[10,1],标签
    for X,y in data_iter:
        # 前向传播,计算损失
        # net(X):模型的预测值
        # loss():计算预测值和真实标签的误差
        l = loss(net(X),y)
        # 梯度清零
        trainer.zero_grad()
        # 反向传播计算梯度
        l.backward()
        # 更新参数,梯度下降
        trainer.step()
    # 计算均方误差
    l = loss(net(features),labels)
    # 每轮结束后打印整体损失,观察模型训练效果
    print(f'epoch {epoch+1},loss{l:f}')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值