Gluon教程:线性回归的简洁实现解析

Gluon教程:线性回归的简洁实现解析

d2l-zh d2l-zh 项目地址: https://gitcode.com/gh_mirrors/d2l/d2l-zh

引言

线性回归是机器学习中最基础也最重要的模型之一。在之前的教程中,我们学习了如何从零开始实现线性回归模型。本文将介绍如何利用深度学习框架Gluon来更简洁地实现线性回归,让开发者能够专注于模型设计而非底层实现细节。

数据准备

生成合成数据

我们首先使用d2l.synthetic_data()函数生成一个线性回归问题的合成数据集。这个函数会基于给定的真实权重true_w和偏置true_b生成特征和标签:

true_w = d2l.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

这里生成了1000个样本,每个样本有2个特征,标签由线性关系y = Xw + b + ϵ生成,其中ϵ是噪声项。

数据加载与批处理

现代深度学习框架都提供了高效的数据加载工具。在Gluon中,我们可以使用gluon.data.DataLoader来创建数据迭代器:

def load_array(data_arrays, batch_size, is_train=True):
    dataset = gluon.data.ArrayDataset(*data_arrays)
    return gluon.data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

这个数据迭代器会在每个epoch自动打乱数据(当is_train=True时),并按指定batch_size返回小批量数据。

模型定义

网络结构

在Gluon中,我们可以使用nn.Sequential()容器来快速构建模型。对于线性回归,只需要一个全连接层(Dense层):

from mxnet.gluon import nn
net = nn.Sequential()
net.add(nn.Dense(1))

这里有几个关键点需要注意:

  1. 我们不需要显式指定输入维度,Gluon会自动推断
  2. 输出维度设为1,因为我们进行的是标量回归
  3. Sequential容器允许我们按顺序堆叠多个层

参数初始化

模型参数需要合理初始化才能有效训练。Gluon提供了多种初始化方法:

from mxnet import init
net.initialize(init.Normal(sigma=0.01))

这里我们使用正态分布初始化权重(标准差0.01),偏置默认初始化为0。Gluon采用延迟初始化策略,只有在第一次前向传播时才会真正初始化参数。

训练过程

损失函数

对于线性回归问题,最常用的损失函数是均方误差(MSE):

loss = gluon.loss.L2Loss()

L2Loss计算预测值与真实值之间的平方差,适用于回归问题。

优化器

我们使用小批量随机梯度下降(SGD)作为优化算法:

from mxnet import gluon
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})

这里指定了:

  1. 要优化的参数(net.collect_params()收集所有可训练参数)
  2. 优化算法('sgd')
  3. 学习率(0.03)

训练循环

完整的训练过程包含以下几个步骤:

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        with autograd.record():
            l = loss(net(X), y)  # 前向传播计算损失
        l.backward()  # 反向传播计算梯度
        trainer.step(batch_size)  # 更新参数
    l = loss(net(features), labels)  # 计算整个训练集的损失
    print(f'epoch {epoch + 1}, loss {l.mean().asnumpy():f}')

每个epoch中:

  1. 遍历数据迭代器获取小批量数据
  2. 前向传播计算损失
  3. 反向传播计算梯度
  4. 优化器更新参数
  5. 最后计算并打印整个训练集的平均损失

结果验证

训练完成后,我们可以检查模型学到的参数与真实参数的接近程度:

w = net[0].weight.data()
print(f'w的估计误差: {true_w - d2l.reshape(w, true_w.shape)}')
b = net[0].bias.data()
print(f'b的估计误差: {true_b - b}')

在理想情况下,估计误差应该很小,说明模型成功学习到了数据的生成规律。

关键知识点总结

  1. 数据加载:Gluon提供了DataLoader等工具简化数据加载和批处理
  2. 模型构建:使用Sequential容器可以快速构建顺序模型
  3. 参数初始化:框架提供了多种初始化方法,并支持延迟初始化
  4. 损失函数:内置了常见损失函数如L2Loss
  5. 优化器:封装了各种优化算法,简化参数更新过程
  6. 训练流程:标准流程包括前向传播、反向传播和参数更新

扩展思考

  1. 学习率调整:如果将损失改为小批量平均损失,学习率需要相应调整
  2. 替代损失函数:可以尝试Huber损失等鲁棒损失函数
  3. 梯度访问:可以通过autograd相关功能访问模型梯度

通过使用深度学习框架,我们能够用更简洁的代码实现相同的功能,同时获得更好的性能和可扩展性。这对于实现更复杂的模型尤为重要。

d2l-zh d2l-zh 项目地址: https://gitcode.com/gh_mirrors/d2l/d2l-zh

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

劳丽娓Fern

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值