### 使用 PyTorch 构建卷积神经网络 (CNN)
构建卷积神经网络(CNN)涉及定义模型架构、准备数据集以及训练过程。下面是一个简单的例子来展示如何使用PyTorch创建并训练一个用于MNIST手写数字识别任务的CNN。
#### 定义 CNN 模型
首先导入必要的库,并定义一个继承自`nn.Module`类的新类,该类代表了整个网络结构:
```python
import torch
from torch import nn, optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 卷积层
self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
# 全连接层
self.fc1 = nn.Linear(32 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 32 * 7 * 7) # 展平张量
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
```
此代码片段展示了怎样通过组合多个卷积层(`Conv2d`)和最大池化操作(`MaxPool2d`)来设计特征提取器部分;接着利用两层全连接层完成分类工作[^3]。
#### 数据预处理与加载
为了能够顺利地输入到上述定义好的CNN中去,在实际应用之前还需要对原始图像做一定的变换处理:
```python
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL Image转换成tensor形式
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
```
这里采用了标准的数据增强方法——即将图片转化为张量格式以便于后续计算[^1]。
#### 设置优化器与损失函数
接下来设置好用来更新权重参数的学习算法(即优化器),还有衡量预测结果好坏的标准(即损失函数)。对于多类别分类问题而言,交叉熵通常是个不错的选择:
```python
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
#### 开始训练循环
最后一步就是编写训练逻辑,迭代遍历所有批次的数据来进行前向传播、反向传播及参数调整:
```python
for epoch in range(num_epochs): # 对整个数据集重复num_epochs次
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/(i+1)}')
```
以上便是基于PyTorch框架搭建简单版CNN的整体流程概述。