PyTorch 是一个强大的深度学习框架,广泛应用于研究和生产环境。以下是一个完整的 PyTorch 教程,涵盖从基础到高级的内容,并提供实例代码。
1. PyTorch 基础
1.1 安装 PyTorch
pip install torch torchvision
1.2 张量(Tensors)
张量是 PyTorch 的核心数据结构,类似于 NumPy 的数组。
import torch
# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
print(x)
# 基本操作
y = torch.tensor([4.0, 5.0, 6.0])
print(x + y) # 加法
print(torch.dot(x, y)) # 点积
# 张量形状
matrix = torch.tensor([[1, 2], [3, 4]])
print(matrix.shape) # 输出: torch.Size([2, 2])
1.3 自动求导(Autograd)
PyTorch 的 autograd
模块支持自动微分。
# 创建需要梯度的张量
x = torch.tensor([2.0], requires_grad=True)
# 定义函数
y = x ** 2 + 3 * x + 1
# 计算梯度
y.backward()
print(x.grad) # 输出: tensor([7.]) (dy/dx = 2x + 3)
2. 神经网络基础
2.1 定义神经网络
使用 torch.nn
模块定义神经网络。
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 50) # 输入 10 维,输出 50 维
self.fc2 = nn.Linear(50, 1) # 输入 50 维,输出 1 维
def forward(self, x):
x = F.relu(self.fc1(x)) # 激活函数
x = self.fc2(x)
return x
# 实例化网络
model = SimpleNet()
print(model)
2.2 训练神经网络
使用 torch.optim
模块进行优化。
import torch.optim as optim
# 定义损失函数和优化器
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 模拟数据
inputs = torch.randn(100, 10) # 100 个样本,每个样本 10 维
targets = torch.randn(100, 1) # 100 个目标值
# 训练循环
for epoch in range(100):
optimizer.zero_grad() # 清空梯度
outputs = model(inputs) # 前向传播
loss = criterion(outputs, targets) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch + 1}/100], Loss: {loss.item():.4f}')
3. 数据加载与预处理
3.1 使用 DataLoader
DataLoader
用于批量加载数据。
from torch.utils.data import DataLoader, TensorDataset
# 创建数据集
dataset = TensorDataset(inputs, targets)
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# 遍历 DataLoader
for batch_inputs, batch_targets in dataloader:
print(batch_inputs.shape, batch_targets.shape)
3.2 图像数据预处理
使用 torchvision
处理图像数据。
from torchvision import datasets, transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化
])
# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 查看数据
images, labels = next(iter(train_loader))
print(images.shape) # 输出: torch.Size([32, 1, 28, 28])
4. 卷积神经网络(CNN)
4.1 定义 CNN
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc1 = nn.Linear(32 * 14 * 14, 10) # MNIST 输出 10 类
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # 卷积 -> 激活 -> 池化
x = x.view(-1, 32 * 14 * 14) # 展平
x = self.fc1(x)
return x
# 实例化 CNN
model = CNN()
print(model)
4.2 训练 CNN
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练循环
for epoch in range(5):
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch + 1}/5], Loss: {loss.item():.4f}')
5. 模型保存与加载
5.1 保存模型
torch.save(model.state_dict(), 'model.pth')
5.2 加载模型
model = CNN()
model.load_state_dict(torch.load('model.pth'))
model.eval() # 设置为评估模式
6. 高级主题
6.1 使用 GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
inputs = inputs.to(device)
targets = targets.to(device)
6.2 自定义数据集
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 使用自定义数据集
dataset = CustomDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
6.3 学习率调度
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(100):
# 训练代码
scheduler.step() # 更新学习率
总结
以上是 PyTorch 的完整教程,涵盖了从基础到高级的内容,并提供了实例代码。通过这些内容,你可以快速掌握 PyTorch 的核心功能,并应用于实际项目中。