编写PyTorch模型训练函数涉及多个关键步骤,包括模型定义、设备分配、损失函数和优化器的初始化、数据加载、训练循环的构建以及模型性能的评估。以下是完整的实现方法和注意事项。
### 1. 定义训练函数的基本结构
训练函数通常包括以下几个部分:
- 设备分配(CPU/GPU)
- 模型初始化
- 数据加载器(DataLoader)
- 优化器和损失函数的定义
- 训练循环(epoch和batch级别的迭代)
- 模型保存和性能评估
```python
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
def train_model(model, dataloader: DataLoader, num_epochs=10, learning_rate=0.001,
                device='cuda' if torch.cuda.is_available() else 'cpu',
                criterion=None, optimizer=None):
    """Train *model* on the batches yielded by *dataloader*.

    Args:
        model: the ``nn.Module`` to train (moved to *device* in place).
        dataloader: yields ``(inputs, labels)`` batches; its ``dataset``
            must support ``len()`` for the per-sample epoch loss.
        num_epochs: number of full passes over the data.
        learning_rate: Adam learning rate (ignored when *optimizer* is given).
        device: computation device; defaults to CUDA when available.
        criterion: loss function; defaults to ``nn.CrossEntropyLoss()``.
        optimizer: optimizer instance; defaults to Adam over ``model.parameters()``.

    Returns:
        The trained model (the same object that was passed in).
    """
    model = model.to(device)
    # Allow callers to plug in their own loss/optimizer; defaults match the
    # original behavior exactly.
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size so the epoch average is per-sample.
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(dataloader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    return model
```
### 2. 数据加载与预处理
使用`torch.utils.data.DataLoader`来加载数据,支持批量加载和数据增强。可以结合`torchvision.transforms`进行数据预处理:
```python
from torchvision import transforms
from torchvision.datasets import CIFAR10
# Normalize each RGB channel to [-1, 1]: (x - 0.5) / 0.5.
_mean = (0.5, 0.5, 0.5)
_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_mean, _std),
])
# CIFAR-10 training split, downloaded on first use, batched and shuffled.
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
```
### 3. 模型定义
定义一个简单的卷积神经网络(CNN)作为示例:
```python
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Small CNN for 10-class classification of 3x32x32 images (e.g. CIFAR-10)."""

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Two 2x2 poolings halve 32x32 twice -> 8x8 feature maps, 32 channels.
        self.fc1 = nn.Linear(32 * 8 * 8, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten all dims except batch. Unlike view(-1, 32*8*8), this raises on
        # unexpected spatial sizes instead of silently remapping the batch dim.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x
```
### 4. 模型训练与保存
调用训练函数并传入模型和数据加载器:
```python
model = SimpleCNN()
trained_model = train_model(model, train_loader, num_epochs=10, learning_rate=0.001)
# Save the trained weights (state_dict only, not the whole module)
torch.save(trained_model.state_dict(), 'trained_model.pth')
```
### 5. 模型评估
在训练完成后,可以通过测试集评估模型性能:
```python
def evaluate_model(model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Evaluate classification accuracy of *model* on *test_loader*.

    Args:
        model: trained ``nn.Module`` producing per-class logits.
        test_loader: yields ``(inputs, labels)`` batches.
        device: computation device; defaults to CUDA when available.

    Returns:
        Accuracy as a percentage in [0, 100]; 0.0 for an empty loader.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # inference only, no gradients needed
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the highest logit (.data is legacy API).
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty loader to avoid ZeroDivisionError.
    accuracy = 100 * correct / total if total else 0.0
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy
```
### 6. 完整代码整合
将上述代码整合为完整的训练流程:
```python
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import CIFAR10
class SimpleCNN(nn.Module):
    """Small CNN for 10-class classification of 3x32x32 images (e.g. CIFAR-10)."""

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Two 2x2 poolings halve 32x32 twice -> 8x8 feature maps, 32 channels.
        self.fc1 = nn.Linear(32 * 8 * 8, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten all dims except batch. Unlike view(-1, 32*8*8), this raises on
        # unexpected spatial sizes instead of silently remapping the batch dim.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x
def train_model(model, dataloader: DataLoader, num_epochs=10, learning_rate=0.001,
                device='cuda' if torch.cuda.is_available() else 'cpu',
                criterion=None, optimizer=None):
    """Train *model* on the batches yielded by *dataloader*.

    Args:
        model: the ``nn.Module`` to train (moved to *device* in place).
        dataloader: yields ``(inputs, labels)`` batches; its ``dataset``
            must support ``len()`` for the per-sample epoch loss.
        num_epochs: number of full passes over the data.
        learning_rate: Adam learning rate (ignored when *optimizer* is given).
        device: computation device; defaults to CUDA when available.
        criterion: loss function; defaults to ``nn.CrossEntropyLoss()``.
        optimizer: optimizer instance; defaults to Adam over ``model.parameters()``.

    Returns:
        The trained model (the same object that was passed in).
    """
    model = model.to(device)
    # Allow callers to plug in their own loss/optimizer; defaults match the
    # original behavior exactly.
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size so the epoch average is per-sample.
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(dataloader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    return model
def evaluate_model(model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Evaluate classification accuracy of *model* on *test_loader*.

    Args:
        model: trained ``nn.Module`` producing per-class logits.
        test_loader: yields ``(inputs, labels)`` batches.
        device: computation device; defaults to CUDA when available.

    Returns:
        Accuracy as a percentage in [0, 100]; 0.0 for an empty loader.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # inference only, no gradients needed
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the highest logit (.data is legacy API).
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty loader to avoid ZeroDivisionError.
    accuracy = 100 * correct / total if total else 0.0
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy
# Data preprocessing and loading: normalize each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Model training
model = SimpleCNN()
trained_model = train_model(model, train_loader, num_epochs=10, learning_rate=0.001)
# Save the trained weights (state_dict only, not the whole module)
torch.save(trained_model.state_dict(), 'trained_model.pth')
# Model evaluation on the held-out test split
evaluate_model(trained_model, test_loader)
```
### 7. 模型微调(Fine-tuning)
如果使用预训练模型,可以通过加载预训练权重并冻结部分层进行微调:
```python
import torch.nn as nn
import torchvision.models as models

# Load a ResNet-50 pre-trained on ImageNet.
# (`pretrained=True` is deprecated; the weights enum is the current torchvision API.)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Freeze all parameters so only the new head is trained.
for param in model.parameters():
    param.requires_grad = False
# Replace the final fully-connected layer for the new task.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # assume the new task has 10 classes
```
### 8. 小结
以上即为使用PyTorch编写模型训练函数的完整流程,可根据具体任务调整模型结构、超参数和数据预处理方式。