import argparse
import copy
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
# Dataset: images organized as one subdirectory per class
class GarbageDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # Class names come from the subdirectories of data_dir
        self.classes = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.images = []
        self.labels = []
        # Collect image paths and their labels
        for class_name in self.classes:
            class_dir = os.path.join(data_dir, class_name)
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.images.append(os.path.join(class_dir, img_name))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        # Load the image and apply the transform pipeline
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
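
# Quick sanity check (a sketch; 'garbage_dataset' below is a placeholder path):
# instantiating the dataset without a transform lets you verify the folder
# layout and label mapping before training.
#
#   ds = GarbageDataset('garbage_dataset')
#   print(len(ds), ds.class_to_idx)
#   img, label = ds[0]   # (PIL.Image, int)
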
# ResNet-based classifier with a custom classification head
class ResNetGarbageClassifier(nn.Module):
    def __init__(self, num_classes, model_name='resnet50', pretrained=True):
        super().__init__()
        # Note: `pretrained=` is deprecated in torchvision >= 0.13 in favor of
        # `weights=`, but it still loads ImageNet weights on older versions.
        if model_name == 'resnet18':
            self.backbone = models.resnet18(pretrained=pretrained)
        elif model_name == 'resnet34':
            self.backbone = models.resnet34(pretrained=pretrained)
        elif model_name == 'resnet50':
            self.backbone = models.resnet50(pretrained=pretrained)
        elif model_name == 'resnet101':
            self.backbone = models.resnet101(pretrained=pretrained)
        else:
            raise ValueError(f"Unsupported model: {model_name}")
        # Input feature count of the original fully connected layer
        num_features = self.backbone.fc.in_features
        # Replace the final fully connected layer with a two-layer head
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)
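
# Optionally freeze the pretrained backbone and train only the new head.
# A minimal sketch: `freeze_backbone` is a helper introduced here for
# illustration and is not called anywhere in this script by default.
def freeze_backbone(model, unfreeze_head=True):
    """Disable gradients for the backbone, optionally keeping the fc head trainable."""
    for param in model.backbone.parameters():
        param.requires_grad = False
    if unfreeze_head:
        # The replaced fc head is a submodule of the backbone, so re-enable it
        for param in model.backbone.fc.parameters():
            param.requires_grad = True
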
# Training loop with per-epoch validation and best-checkpoint saving
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 50)
        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        train_bar = tqdm(train_loader, desc='Training')
        for inputs, labels in train_bar:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            train_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{torch.sum(preds == labels.data).double() / inputs.size(0):.4f}'
            })
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.double() / len(train_loader.dataset)
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc.item())
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0
        with torch.no_grad():
            val_bar = tqdm(val_loader, desc='Validation')
            for inputs, labels in val_bar:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item() * inputs.size(0)
                val_running_corrects += torch.sum(preds == labels.data)
                val_bar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{torch.sum(preds == labels.data).double() / inputs.size(0):.4f}'
                })
        val_epoch_loss = val_running_loss / len(val_loader.dataset)
        val_epoch_acc = val_running_corrects.double() / len(val_loader.dataset)
        val_losses.append(val_epoch_loss)
        val_accs.append(val_epoch_acc.item())
        print(f'Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}')
        # Step the learning rate schedule
        if scheduler:
            scheduler.step()
        # Save the best checkpoint; deepcopy so later training steps
        # cannot mutate the saved weights in place
        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, 'best_garbage_model.pth')
            print(f'New best model saved with accuracy: {best_acc:.4f}')
        print()
    # Restore the best weights before returning
    model.load_state_dict(best_model_wts)
    return model, {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs
    }

# Visualize the training history
def plot_training_history(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    # Loss curves
    ax1.plot(history['train_losses'], label='Train Loss')
    ax1.plot(history['val_losses'], label='Val Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    # Accuracy curves
    ax2.plot(history['train_accs'], label='Train Acc')
    ax2.plot(history['val_accs'], label='Val Acc')
    ax2.set_title('Training and Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

# Evaluate overall and per-class accuracy
def evaluate_model(model, test_loader, device, class_names):
    model.eval()
    correct = 0
    total = 0
    class_correct = [0.0] * len(class_names)
    class_total = [0.0] * len(class_names)
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating'):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Per-sample comparison; indexing directly avoids the 0-dim edge
            # case that `.squeeze()` hits when the last batch has one sample
            for i in range(labels.size(0)):
                label = labels[i].item()
                class_correct[label] += (predicted[i] == labels[i]).item()
                class_total[label] += 1
    print(f'Overall Accuracy: {100 * correct / total:.2f}%')
    print('\nPer-class Accuracy:')
    for i in range(len(class_names)):
        if class_total[i] > 0:
            print(f'{class_names[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')
    return 100 * correct / total
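
# Optional inference helper: classify one image with a trained model.
# A minimal sketch; `predict_image` is a name introduced here for
# illustration, and `transform` should match the validation-time pipeline.
def predict_image(model, image_path, transform, class_names, device):
    """Return (predicted class name, softmax confidence) for a single image."""
    model.eval()
    image = Image.open(image_path).convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    conf, idx = torch.max(probs, dim=0)
    return class_names[idx.item()], conf.item()
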
def main():
    parser = argparse.ArgumentParser(description='Garbage Classification Training')
    parser.add_argument('--data_dir', type=str, default='garbage_dataset', help='Dataset directory')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--model', type=str, default='resnet50', help='Model name')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='Training set ratio')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='Validation set ratio')
    args = parser.parse_args()

    # Configuration
    config = {
        'data_dir': args.data_dir,
        'batch_size': args.batch_size,
        'num_epochs': args.epochs,
        'learning_rate': args.lr,
        'model_name': args.model,
        'pretrained': True,
        'num_workers': 4,
        'train_ratio': args.train_ratio,
        'val_ratio': args.val_ratio,
        'test_ratio': 1 - args.train_ratio - args.val_ratio
    }

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Data preprocessing and augmentation
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Build the dataset twice so the splits can use different transforms.
    # Assigning to `subset.dataset.transform` after random_split would change
    # the transform for every split, because all subsets share one underlying
    # dataset object.
    train_base = GarbageDataset(config['data_dir'], transform=train_transform)
    eval_base = GarbageDataset(config['data_dir'], transform=val_transform)

    # Class information
    class_names = train_base.classes
    num_classes = len(class_names)
    print(f'Number of classes: {num_classes}')
    print(f'Class names: {class_names}')

    # Split the dataset
    dataset_size = len(train_base)
    train_size = int(config['train_ratio'] * dataset_size)
    val_size = int(config['val_ratio'] * dataset_size)
    test_size = dataset_size - train_size - val_size
    assert min(train_size, val_size, test_size) > 0, 'Every split must be non-empty'

    # Identical generator seeds produce identical index splits for both copies
    split_sizes = [train_size, val_size, test_size]
    train_dataset, _, _ = random_split(
        train_base, split_sizes, generator=torch.Generator().manual_seed(42)
    )
    _, val_dataset, test_dataset = random_split(
        eval_base, split_sizes, generator=torch.Generator().manual_seed(42)
    )

    # Data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers']
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers']
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers']
    )
    print(f'Train samples: {len(train_dataset)}')
    print(f'Val samples: {len(val_dataset)}')
    print(f'Test samples: {len(test_dataset)}')

    # Build the model
    model = ResNetGarbageClassifier(
        num_classes=num_classes,
        model_name=config['model_name'],
        pretrained=config['pretrained']
    )
    model = model.to(device)

    # Loss function, optimizer, and learning rate schedule
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    # Train
    print('Starting training...')
    model, history = train_model(
        model, train_loader, val_loader, criterion,
        optimizer, scheduler, config['num_epochs'], device
    )

    # Visualize the training history
    plot_training_history(history)

    # Evaluate on the held-out test set
    print('Evaluating on test set...')
    test_accuracy = evaluate_model(model, test_loader, device, class_names)

    # Save the final model along with the metadata needed for inference
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_names': class_names,
        'config': config
    }, 'final_garbage_model.pth')
    print(f'Test accuracy: {test_accuracy:.2f}%')
    print('Training completed!')


if __name__ == '__main__':
    main()
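
# Reloading the saved checkpoint later (a sketch, using the metadata saved above):
#
#   ckpt = torch.load('final_garbage_model.pth', map_location='cpu')
#   model = ResNetGarbageClassifier(
#       num_classes=len(ckpt['class_names']),
#       model_name=ckpt['config']['model_name'],
#       pretrained=False,
#   )
#   model.load_state_dict(ckpt['model_state_dict'])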