DenseNet(Densely Connected Convolutional Networks)是一种用于图像分类的深度学习模型,由Gao Huang等人在2017年提出。DenseNet通过密集连接的方式,使得每一层都直接连接到所有后续层,从而增强了特征传播和重用,减少了参数数量,并提高了模型的性能。
下面是一个使用PyTorch实现DenseNet进行图像分类的示例代码。
首先,确保你已经安装了必要的库:
pip install torch torchvision
注意:具体需要依据cuda版本来选择对应版本
2. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
from torch.optim.lr_scheduler import StepLR
3. 定义DenseNet模型
PyTorch的`torchvision.models`模块中已经提供了预定义的DenseNet模型。我们可以直接使用这些模型,或者根据需要对其进行微调。
# 使用预定义的DenseNet121模型
model = models.densenet121(pretrained=True)
# 修改最后的全连接层以适应你的分类任务
num_classes = 10 # 假设你有10个类别
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
# 将模型移动到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
4. 定义数据加载器
我们需要定义数据加载器来加载训练和测试数据。这里我们使用CIFAR-10数据集作为示例。
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)), # DenseNet输入尺寸为224x224
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
5. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 学习率调度器
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
```
### 6. 训练模型
```python
num_epochs = 10
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for i, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 100 == 99: # 每100个batch打印一次损失
print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
running_loss = 0.0
# 更新学习率
scheduler.step()
# 每个epoch结束后在测试集上评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Epoch [{epoch + 1}/{num_epochs}], Test Accuracy: {100 * correct / total:.2f}%')
6. 保存模型
训练完成后,你可以将模型保存到磁盘以便后续使用。
torch.save(model.state_dict(), 'densenet_cifar10.pth')
7. 加载模型并进行推理
如果你需要加载已经训练好的模型进行推理,可以使用以下代码:
model.load_state_dict(torch.load('densenet_cifar10.pth'))
model.eval()
# 进行推理
with torch.no_grad():
inputs = test_dataset[0][0].unsqueeze(0).to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
print(f'Predicted: {predicted.item()}')
以上展示了如何使用PyTorch实现DenseNet进行图像分类任务。实际实现时可以根据需要调整模型结构、数据集和超参数来适应不同的任务。DenseNet由于其密集连接的特性,通常在图像分类任务中表现出色。