MXNet深度学习框架中的逻辑回归详解

MXNet深度学习框架中的逻辑回归详解

mxnet mxnet 项目地址: https://gitcode.com/gh_mirrors/mx/mxnet

逻辑回归是深度学习中最基础且重要的模型之一,本文将通过MXNet的Gluon API详细介绍如何实现逻辑回归模型。我们将从数据准备、模型构建到训练评估完整走一遍流程,并重点讲解实现过程中的关键点和常见陷阱。

环境准备与数据生成

首先导入必要的MXNet模块:

import numpy as onp
import mxnet as mx
from mxnet import np, npx, autograd, gluon
from mxnet.gluon import nn, Trainer
from mxnet.gluon.data import DataLoader, ArrayDataset

mx.np.random.seed(12345)  # 设置随机种子保证结果可复现

我们生成一个模拟数据集,包含10个特征和一个二分类标签。标签生成采用非随机逻辑,使模型能够学习到有效模式:

def get_random_data(size, device):
    x = np.random.normal(0, 1, size=(size, 10), device=device)
    y = x.sum(axis=1) > 3  # 特征总和大于3的样本标记为1
    return x, y

设置基本超参数,这里使用CPU进行计算:

device = mx.cpu()
train_data_size = 1000  # 训练集大小
val_data_size = 100     # 验证集大小
batch_size = 10         # 批大小

数据加载与预处理

MXNet提供了Dataset和DataLoader类来处理数据。Dataset提供索引式数据访问,DataLoader负责数据的洗牌和批处理:

# 生成训练集和验证集
train_x, train_y = get_random_data(train_data_size, device)
train_dataset = ArrayDataset(train_x, train_y)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_x, val_y = get_random_data(val_data_size, device)
val_dataset = ArrayDataset(val_x, val_y)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

模型构建

逻辑回归模型的关键是输出层只需一个神经元。我们构建一个包含输入层(10神经元)、两个隐藏层(各10神经元)和输出层(1神经元)的网络:

net = nn.HybridSequential()
net.add(nn.Dense(units=10, activation='relu'))  # 输入层
net.add(nn.Dense(units=10, activation='relu'))  # 隐藏层1
net.add(nn.Dense(units=10, activation='relu'))  # 隐藏层2
net.add(nn.Dense(units=1))                     # 输出层:必须只有1个神经元

net.initialize(mx.init.Xavier())  # Xavier初始化

损失函数与优化器

对于二分类问题,我们使用SigmoidBinaryCrossEntropyLoss作为损失函数。优化器选择SGD(随机梯度下降):

loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
trainer = Trainer(params=net.collect_params(), optimizer='sgd',
                  optimizer_params={'learning_rate': 0.1})

评估指标使用准确率(Accuracy)和F1分数:

accuracy = mx.gluon.metric.Accuracy()
f1 = mx.gluon.metric.F1()

训练与验证过程

训练函数实现前向传播、损失计算和参数更新:

def train_model():
    cumulative_train_loss = 0
    for data, label in train_dataloader:
        with autograd.record():
            output = net(data)
            loss_result = loss(output, label)
        loss_result.backward()
        trainer.step(batch_size)
        cumulative_train_loss += np.sum(loss_result).item()
    return cumulative_train_loss

验证函数需要特别注意指标计算。由于Accuracy需要类别预测而非概率,我们需要对sigmoid输出进行阈值处理:

def validate_model(threshold=0.5):
    cumulative_val_loss = 0
    for val_data, val_y in val_dataloader:
        output = net(val_data)
        cumulative_val_loss += np.sum(loss(output, val_y)).item()
        
        # 将概率转换为类别
        prob = npx.sigmoid(output)
        pred_classes = mx.np.ceil(prob - threshold)
        
        accuracy.update(val_y, pred_classes.reshape(-1))
        
        # F1分数需要两类概率
        prob = prob.reshape(-1)
        probabilities = mx.np.stack([1 - prob, prob], axis=1)
        f1.update(val_y, probabilities)
    return cumulative_val_loss

训练循环

将上述组件组合成完整训练流程:

epochs = 10
for e in range(epochs):
    avg_train_loss = train_model() / train_data_size
    avg_val_loss = validate_model() / val_data_size
    
    print(f"Epoch: {e}, Train loss: {avg_train_loss:.2f}, "
          f"Val loss: {avg_val_loss:.2f}, "
          f"Accuracy: {accuracy.get()[1]:.2f}, "
          f"F1: {f1.get()[1]:.2f}")
    
    accuracy.reset()  # 重置指标

关键实现要点

  1. 输出层设计:尽管是二分类问题,输出层只需一个神经元,因为SigmoidBinaryCrossEntropyLoss只需要一个输入特征。

  2. 标签编码:必须将类别编码为0和1。如果原始数据使用其他编码(如-1和1),需要先进行转换。

  3. 损失函数选择:使用SigmoidBinaryCrossEntropyLoss而非普通交叉熵损失。

  4. 概率到类别的转换:在计算Accuracy前,需要将sigmoid输出的概率通过阈值(通常0.5)转换为类别。

总结

通过MXNet实现逻辑回归需要注意以上关键点。本文展示了完整的实现流程,包括数据准备、模型构建、训练验证和评估指标计算。逻辑回归虽然简单,但正确实现它对于理解更复杂的深度学习模型至关重要。

mxnet mxnet 项目地址: https://gitcode.com/gh_mirrors/mx/mxnet

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

樊会灿

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值