PyTorch构建自定义神经网络

1. 背景:为什么要自定义网络、损失函数和模块?

在深度学习任务中,PyTorch 提供了许多现成的模块(如 torch.nn.Lineartorch.nn.Conv2d)和损失函数(如 torch.nn.CrossEntropyLosstorch.nn.MSELoss),这些工具极大地方便了模型开发。但在以下场景中,你可能需要自定义:

  • 研究创新:研究新的网络架构(如全新注意力机制、特殊正则化方式)或损失函数(如结合多种任务的复合损失)。
  • 特定任务需求:现有模块无法满足任务需求,例如需要设计针对特定数据的激活函数或层。
  • 深入理解:从头实现网络和损失函数有助于理解深度学习的底层原理,如梯度计算、反向传播等。

本文以手写数字识别任务(类似 MNIST 数据集)为例,逐步讲解如何从零构建自定义网络结构、激活函数、损失函数和优化过程。


2. 准备工作:PyTorch 基础和环境设置

在开始之前,确保已安装 PyTorch 和相关依赖。以下是环境设置和基础知识回顾:

2.1 安装 PyTorch

假设使用 Python 3.8+,可通过以下命令安装 PyTorch(以 CPU 版本为例,GPU 版本请参考官方文档):

pip install torch torchvision

2.2 PyTorch 核心概念

  • Tensor:PyTorch 的基本数据结构,类似于 NumPy 的多维数组,支持 GPU 加速。
  • Moduletorch.nn.Module 是所有神经网络模块的基类,自定义层和网络需继承它。
  • Autograd:PyTorch 的自动求导引擎,负责计算梯度和反向传播。
  • Optimizer:如 torch.optim.SGD,用于更新模型参数。
  • Dataset 和 DataLoader:用于加载和批量处理数据。

2.3 示例任务

我们将构建一个简单的全连接神经网络(MLP)解决 MNIST 手写数字分类问题。输入为 28×28 灰度图像(展平为 784 维向量),输出为 10 个类别的概率分布。我们将自定义:

  • 网络结构(不使用 nn.Linear
  • 激活函数(不使用 nn.ReLU
  • 损失函数(不使用 nn.CrossEntropyLoss
  • 训练循环和优化器(部分自定义)。

3. 自定义网络结构

PyTorch 的 nn.Module 是构建神经网络的核心。我们将从头实现一个简单的全连接层(类似 nn.Linear)和一个多层感知机(MLP)。

3.1 自定义全连接层

全连接层的数学形式为:
y=xWT+by = xW^T + by=xWT+b
其中:

  • xxx:输入张量,形状为 [batch_size,input_dim][\text{batch\_size}, \text{input\_dim}][batch_size,input_dim]
  • WWW:权重矩阵,形状为 [output_dim,input_dim][\text{output\_dim}, \text{input\_dim}][output_dim,input_dim]
  • bbb:偏置向量,形状为 [output_dim][\text{output\_dim}][output_dim]
  • yyy:输出张量,形状为 [batch_size, output_dim][\text{batch\_size, output\_dim}][batch_size, output_dim]

我们将创建一个自定义全连接层,包含权重和偏置的初始化,以及前向传播逻辑。

import torch
import torch.nn as nn

class CustomLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CustomLinear, self).__init__()
        # 初始化权重和偏置
        self.input_dim = input_dim
        self.output_dim = output_dim
        # 使用 Kaiming 初始化权重
        self.weight = nn.Parameter(torch.randn(output_dim, input_dim) * (2.0 / input_dim) ** 0.5)
        self.bias = nn.Parameter(torch.zeros(output_dim))
    
    def forward(self, x):
        # 前向传播:y = xW^T + b
        return torch.matmul(x, self.weight.t()) + self.bias

代码解析

  • nn.Parameter:将权重和偏置标记为可学习参数,PyTorch 会自动跟踪其梯度。
  • Kaiming 初始化:权重初始化使用 He 初始化(适合 ReLU 激活函数),公式为 2input_dim\sqrt{\frac{2}{\text{input\_dim}}}input_dim2
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱看烟花的码农

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值