import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from collections import defaultdict
# Plot fonts: SimHei lets matplotlib render the Chinese titles/labels used
# below, with DejaVu Sans as a fallback. The Unicode minus glyph is turned off
# so negative tick labels still display with fonts that typically lack it.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
# -----------------------------
# 1. Device selection
# -----------------------------
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用设备: {device}")
if device.type == 'cuda':
    # Log the name of GPU index 0 for reference.
    print(f"GPU 名称: {torch.cuda.get_device_name(0)}")
# -----------------------------
# 2. Data preprocessing + augmentation
# -----------------------------
# Per-channel normalization statistics shared by the train and test pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop from a 4-pixel padded image and a
# random horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
# Evaluation pipeline: same normalization, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

_DATA_ROOT = './datasets'
train_dataset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=False, download=True, transform=transform_test)

# NOTE(review): num_workers > 0 with no `if __name__ == "__main__":` guard
# breaks on platforms that spawn worker processes (Windows, macOS by default),
# because each worker re-imports this script. Consider wrapping the script body
# in a main guard.
_loader_kwargs = dict(batch_size=128, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
# shuffle=False keeps test-set order, so prediction i corresponds to sample i.
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Human-readable class names, indexed by CIFAR-10 label id.
classes = ('飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '船', '卡车', '马')
# -----------------------------
# 3. Residual block definition
# -----------------------------
class ResidualBlock(nn.Module):
    """Basic residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus a shortcut, then ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first convolution (values > 1 shrink the spatial size).
        downsample: optional module applied to the input on the shortcut path so
            it matches the main path's shape. When None the input is added
            unchanged, so it must already match (same channels, stride 1).
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Attribute names are the state_dict keys used by saved checkpoints;
        # conv1 is created before conv2 so random initialization order is fixed.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample  # shortcut projection when shapes differ

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        branch += shortcut
        return self.relu(branch)
# -----------------------------
# 4. Building ResNet-18
# -----------------------------
def make_layer(block, in_channels, out_channels, blocks, stride=1):
    """Stack ``blocks`` residual blocks into a single ``nn.Sequential`` stage.

    Only the first block may change the shape. It receives ``stride`` and, when
    the stride or channel count changes, a 1x1 conv + BatchNorm projection for
    its shortcut. Every later block maps ``out_channels`` -> ``out_channels``
    with stride 1 and no projection.

    Args:
        block: block class, called as ``block(in_ch, out_ch, stride, downsample)``
            for the first block and ``block(out_ch, out_ch, stride=1)`` afterwards.
        in_channels: channels entering the stage.
        out_channels: channels produced by every block in the stage.
        blocks: number of blocks; at least one block is always created.
        stride: stride of the first block.

    Returns:
        ``nn.Sequential`` containing the blocks in order.
    """
    needs_projection = stride != 1 or in_channels != out_channels
    shortcut = (
        nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if needs_projection
        else None
    )
    stage = [block(in_channels, out_channels, stride, shortcut)]
    stage.extend(block(out_channels, out_channels, stride=1) for _ in range(blocks - 1))
    return nn.Sequential(*stage)
class ResNet(nn.Module):
    """ResNet for small inputs (e.g. 32x32 CIFAR images).

    The stem is a 3x3, stride-1 convolution with no max-pool. It is followed by
    four residual stages (64/128/256/512 channels; stages 2-4 halve the spatial
    size), global average pooling and a linear classifier.

    Args:
        block: residual block class handed to ``make_layer``.
        layers: number of blocks in each of the four stages (``[2, 2, 2, 2]`` -> ResNet-18).
        num_classes: number of output logits.
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        self.in_channels = 64  # NOTE(review): assigned but never read in this file
        # Stem: keep full resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # Residual stages; registration order (layer1..layer4) fixes state_dict layout.
        self.layer1 = make_layer(block, 64, 64, layers[0], stride=1)
        self.layer2 = make_layer(block, 64, 128, layers[1], stride=2)
        self.layer3 = make_layer(block, 128, 256, layers[2], stride=2)
        self.layer4 = make_layer(block, 256, 512, layers[3], stride=2)
        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = torch.flatten(self.avgpool(features), start_dim=1)
        return self.fc(pooled)
# Build the model: [2, 2, 2, 2] blocks per stage gives ResNet-18.
model = ResNet(ResidualBlock, [2, 2, 2, 2]).to(device)
# -----------------------------
# 5. Loss function, optimizer, learning-rate schedule
# -----------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate once the monitored value (validation loss, passed to
# scheduler.step in the training loop) has stopped improving for more than
# `patience`=3 epochs. `verbose` is deprecated in recent PyTorch releases
# (passing it only triggers a deprecation warning). The training loop already
# records the LR every epoch, so it is not passed here.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
# -----------------------------
# 6. Training-log state
# -----------------------------
# Per-epoch history keyed by metric name ('train_loss', 'val_acc', ...).
metrics = defaultdict(list)
lr_list = []  # learning rate in effect during each epoch
# Early-stopping bookkeeping: best validation loss so far and the number of
# consecutive epochs that have not improved on it.
best_val_loss, patience_counter = float('inf'), 0
patience_limit = 7  # stop once this many epochs in a row fail to improve
# -----------------------------
# 7. Training and validation functions
# -----------------------------
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: network to train; it is switched to training mode.
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function. Assumed to return the *mean* loss of a batch
            (e.g. ``nn.CrossEntropyLoss`` with its default reduction), because the
            value is re-weighted by batch size below.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(loss, accuracy)``: the per-sample mean loss and the accuracy in
        percent over all samples seen this epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one in the epoch average.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += batch_size
    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum / total, 100. * correct / total
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: network to evaluate; it is switched to eval mode.
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function. Assumed to return the *mean* loss of a batch,
            because the value is re-weighted by batch size below.
        device: device each batch is moved to.

    Returns:
        ``(loss, accuracy)``: the per-sample mean loss and the accuracy in
        percent over all samples in ``dataloader``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            # Size-weighted so a short final batch is not over-represented.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum / total, 100. * correct / total
# -----------------------------
# 8. Training
# -----------------------------
# NOTE(review): test_loader doubles as the validation set here. It drives the
# LR schedule, early stopping and checkpoint selection, so the test-set
# metrics reported later are optimistically biased. Consider holding out part
# of the training set for validation.
print("开始训练...")
num_epochs = 50
checkpoint_path = 'best_resnet18_cifar10.pth'
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate_epoch(model, test_loader, criterion, device)
    # Record metrics and the learning rate that was in effect for this epoch.
    current_lr = optimizer.param_groups[0]['lr']
    metrics['train_loss'].append(train_loss)
    metrics['train_acc'].append(train_acc)
    metrics['val_loss'].append(val_loss)
    metrics['val_acc'].append(val_acc)
    lr_list.append(current_lr)
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}% | "
          f"LR: {current_lr:.2e}")
    # Learning-rate scheduling on validation loss.
    scheduler.step(val_loss)
    # Early stopping: checkpoint on improvement, stop after patience_limit misses.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        print("✅ 模型已保存")
    else:
        patience_counter += 1
        if patience_counter >= patience_limit:
            print("🛑 早停触发")
            break
# Restore the best weights, but only if this run actually saved a checkpoint.
# If no epoch improved (e.g. NaN validation loss never compares below inf),
# loading would either fail or silently pick up a stale file from an earlier run.
if best_val_loss < float('inf'):
    # map_location allows loading on the current device; weights_only restricts
    # unpickling to tensors/containers, which is all a state_dict needs.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
else:
    print("⚠️ 未保存任何检查点,保留最后一个 epoch 的权重")
# -----------------------------
# 9. Training curves
# -----------------------------
# One x value per recorded epoch, numbered from 1 to match the "Epoch [k/N]"
# log lines. Plotting the bare lists would put the first epoch at x=0.
# lr_list and the metrics lists are appended together once per epoch.
epoch_axis = list(range(1, len(lr_list) + 1))
plt.figure(figsize=(15, 5))
# Accuracy panel.
plt.subplot(1, 3, 1)
plt.plot(epoch_axis, metrics['train_acc'], label='训练准确率')
plt.plot(epoch_axis, metrics['val_acc'], label='验证准确率')
plt.title('模型准确率')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
# Loss panel.
plt.subplot(1, 3, 2)
plt.plot(epoch_axis, metrics['train_loss'], label='训练损失')
plt.plot(epoch_axis, metrics['val_loss'], label='验证损失')
plt.title('模型损失')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# Learning-rate panel; log scale because the scheduler changes the LR by a
# constant factor.
plt.subplot(1, 3, 3)
plt.plot(epoch_axis, lr_list, label='学习率', color='purple')
plt.title('学习率变化')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.yscale('log')
plt.legend()
plt.tight_layout()
plt.show()
# -----------------------------
# 10. Test-set predictions & confusion matrix
# -----------------------------
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
    for inputs, labels in test_loader:
        # Only the inputs need to go to the device; labels stay on the CPU
        # because they are only collected for sklearn.
        outputs = model(inputs.to(device))
        _, preds = torch.max(outputs, 1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
# Pin the label set to every class index. The matrix is then always
# len(classes) x len(classes) and lines up with the tick labels, even if some
# class appears in neither y_true nor y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.title('混淆矩阵')
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.show()
# -----------------------------
# 11. Visualize misclassified samples
# -----------------------------
# test_loader uses shuffle=False, so position i in y_true / y_pred is sample i
# of test_dataset. The error indices are therefore dataset indices directly.
error_indices = np.where(np.array(y_pred) != np.array(y_true))[0]
print(f"共 {len(error_indices)} 个样本被错误分类")
num_show = min(12, len(error_indices))
# Inverse of the Normalize step in transform_test (per-channel mean/std).
norm_mean = np.array([0.4914, 0.4822, 0.4465])
norm_std = np.array([0.2023, 0.1994, 0.2010])
fig, axes = plt.subplots(2, 6, figsize=(15, 6))
# Hide every cell up front, so grid cells left empty when fewer than 12 errors
# exist show no frames.
for ax in axes.flat:
    ax.axis('off')
for i, idx in enumerate(error_indices[:num_show]):
    # Index the dataset by sample position (the old `idx // 128` treated it
    # as a batch number and showed the wrong image).
    img = test_dataset[int(idx)][0]
    true_label, pred_label = y_true[idx], y_pred[idx]
    img_np = img.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
    img_np = np.clip(img_np * norm_std + norm_mean, 0, 1)  # de-normalize
    ax = axes[i // 6, i % 6]
    ax.imshow(img_np)
    ax.set_title(f'真:{classes[true_label]}, 预:{classes[pred_label]}')
plt.suptitle("错误分类的样本示例")
plt.tight_layout()
plt.show()
# -----------------------------
# 12. Classification report
# -----------------------------
print("分类报告:")
# Pass the full label set so target_names always lines up with the class
# indices. Without it, sklearn raises if some class appears in neither
# y_true nor y_pred.
print(classification_report(y_true, y_pred,
                            labels=list(range(len(classes))),
                            target_names=classes))