1. 背景:为什么要自定义网络、损失函数和模块?
在深度学习任务中,PyTorch 提供了许多现成的模块(如 torch.nn.Linear、torch.nn.Conv2d)和损失函数(如 torch.nn.CrossEntropyLoss、torch.nn.MSELoss),这些工具极大地方便了模型开发。但在以下场景中,你可能需要自定义:
- 研究创新:研究新的网络架构(如全新注意力机制、特殊正则化方式)或损失函数(如结合多种任务的复合损失)。
- 特定任务需求:现有模块无法满足任务需求,例如需要设计针对特定数据的激活函数或层。
- 深入理解:从头实现网络和损失函数有助于理解深度学习的底层原理,如梯度计算、反向传播等。
本文以手写数字识别任务(类似 MNIST 数据集)为例,逐步讲解如何从零构建自定义网络结构、激活函数、损失函数和优化过程。
2. 准备工作:PyTorch 基础和环境设置
在开始之前,确保已安装 PyTorch 和相关依赖。以下是环境设置和基础知识回顾:
2.1 安装 PyTorch
假设使用 Python 3.8+,可通过以下命令安装 PyTorch(以 CPU 版本为例,GPU 版本请参考官方文档):
pip install torch torchvision
2.2 PyTorch 核心概念
- Tensor:PyTorch 的基本数据结构,类似于 NumPy 的多维数组,支持 GPU 加速。
- Module:
torch.nn.Module是所有神经网络模块的基类,自定义层和网络需继承它。 - Autograd:PyTorch 的自动求导引擎,负责计算梯度和反向传播。
- Optimizer:如
torch.optim.SGD,用于更新模型参数。 - Dataset 和 DataLoader:用于加载和批量处理数据。
2.3 示例任务
我们将构建一个简单的全连接神经网络(MLP)解决 MNIST 手写数字分类问题。输入为 28×28 灰度图像(展平为 784 维向量),输出为 10 个类别的概率分布。我们将自定义:
- 网络结构(不使用
nn.Linear) - 激活函数(不使用
nn.ReLU) - 损失函数(不使用
nn.CrossEntropyLoss) - 训练循环和优化器(部分自定义)。
3. 自定义网络结构
PyTorch 的 nn.Module 是构建神经网络的核心。我们将从头实现一个简单的全连接层(类似 nn.Linear)和一个多层感知机(MLP)。
3.1 自定义全连接层
全连接层的数学形式为:
y=xWT+by = xW^T + by=xWT+b
其中:
- xxx:输入张量,形状为 [batch_size,input_dim][\text{batch\_size}, \text{input\_dim}][batch_size,input_dim]
- WWW:权重矩阵,形状为 [output_dim,input_dim][\text{output\_dim}, \text{input\_dim}][output_dim,input_dim]
- bbb:偏置向量,形状为 [output_dim][\text{output\_dim}][output_dim]
- yyy:输出张量,形状为 [batch_size, output_dim][\text{batch\_size, output\_dim}][batch_size, output_dim]
我们将创建一个自定义全连接层,包含权重和偏置的初始化,以及前向传播逻辑。
import torch
import torch.nn as nn
class CustomLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(CustomLinear, self).__init__()
# 初始化权重和偏置
self.input_dim = input_dim
self.output_dim = output_dim
# 使用 Kaiming 初始化权重
self.weight = nn.Parameter(torch.randn(output_dim, input_dim) * (2.0 / input_dim) ** 0.5)
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, x):
# 前向传播:y = xW^T + b
return torch.matmul(x, self.weight.t()) + self.bias
代码解析:
nn.Parameter:将权重和偏置标记为可学习参数,PyTorch 会自动跟踪其梯度。- Kaiming 初始化:权重初始化使用 He 初始化(适合 ReLU 激活函数),公式为 2input_dim\sqrt{\frac{2}{\text{input\_dim}}}input_dim2

最低0.47元/天 解锁文章
1050

被折叠的 条评论
为什么被折叠?



