小批量梯度下降法 (Mini-Batch Gradient Descent)

小批量梯度下降法 (Mini-Batch Gradient Descent, MBGD) 是梯度下降法的一种变体。与批量梯度下降法 (Batch Gradient Descent) 和随机梯度下降法 (Stochastic Gradient Descent, SGD) 不同,小批量梯度下降法在每次迭代中使用一个小批量 (mini-batch) 样本的数据来更新参数。这样既可以减少计算开销,又能够更稳定地更新参数。

原理

小批量梯度下降法的基本思想是通过在每次迭代中使用一部分数据样本 (mini-batch) 来计算梯度,从而更新参数。这种方法在计算效率和收敛速度之间取得了平衡。相比于 SGD,它能够更稳定地更新参数,避免一些噪声引起的波动;相比于批量梯度下降法,它减少了每次迭代的计算量,加快了更新速度。

核心公式

在第 t 次迭代中,选择一个小批量 Bt样本,并计算梯度:

其中:

  • θ(t) 为第 t 次迭代的参数
  • η 为学习率
  • J(θ; Bt) 为目标函数在小批量 Bt上的梯度

目标是最小化目标函数 J(θ)J(\theta)J(θ),即:

Python 示例

我们将使用 Scikit-Learn 中的线性回归模型进行训练,并绘制多个数据分析图形。具体的图形包括:损失函数的收敛图、预测结果与实际结果的比较图、以及参数更新的轨迹图。

数据生成和分割

首先,我们生成一个带有噪声的线性回归数据集,并将其划分为训练集和测试集。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDRegressor
from sklearn.metrics import mean_squared_error

# 生成回归数据集
X, y = make_regression(n_samples=1000, n_features=1, noise=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 打印前5条数据和对应的目标值
print("前5条训练数据:")
print(X_train[:5])
print("前5个训练目标值:")
print(y_train[:5])
小批量梯度下降法回归模型初始化

使用 SGDRegressor 初始化随机梯度下降回归模型,设定最大迭代次数、禁用容忍度检查、启用 warm_start,每次迭代使用上次的解,设定常数学习率和初始学习率。

# 初始化小批量梯度下降回归模型
sgd = SGDRegressor(max_iter=1, tol=None, warm_start=True, learning_rate='constant', eta0=0.01, random_state=42)
训练模型

进行多次迭代(epochs),在每个 epoch 进行一次模型训练,记录训练误差和测试误差,以及每次迭代后的系数值。

# 记录训练过程中的信息
n_epochs = 50
batch_size = 32
train_errors, test_errors = [], []
coef_updates = []

# 训练模型
for epoch in range(n_epochs):
    for i in range(0, X_train.shape[0], batch_size):
        X_batch = X_train[i:i+batch_size]
        y_batch = y_train[i:i+batch_size]
        sgd.fit(X_batch, y_batch)
    
    y_train_predict = sgd.predic
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值