用PyTorch训练一个猫狗分类器

数据集准备

使用Kaggle提供的猫狗分类数据集(Dogs vs Cats),包含25,000张图片(12,500张猫和12,500张狗)。数据集可从以下链接下载:

  • 官方地址:https://www.kaggle.com/c/dogs-vs-cats/data

解压后目录结构应如下:

data/
    train/
        cat.0.jpg
        dog.0.jpg
        ...
    test/
        test1.jpg
        test2.jpg
        ...

数据预处理

使用torchvision.transforms对图像进行标准化和增强:

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

数据加载

使用ImageFolderDataLoader加载数据:

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

train_dataset = ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = ImageFolder('data/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

模型定义

使用预训练的ResNet18模型并进行微调:

import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # 替换全连接层

训练配置

设置损失函数和优化器:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

训练循环

实现完整的训练过程:

for epoch in range(10):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

模型评估

在测试集上评估模型性能:

correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

源码保存

将完整代码保存为cat_dog_classifier.py

# 完整代码包含上述所有片段
# 添加必要的import语句和主函数入口

模型保存与加载

训练完成后保存模型:

torch.save(model.state_dict(), 'cat_dog_classifier.pth')

加载已保存的模型:

model.load_state_dict(torch.load('cat_dog_classifier.pth'))

扩展建议

  1. 尝试使用不同的预训练模型(如VGG16、EfficientNet)
  2. 增加数据增强方法(如随机旋转、颜色抖动)
  3. 实现学习率调度器
  4. 添加早停机制防止过拟合
  5. 使用TensorBoard记录训练过程

完整项目代码可参考GitHub仓库: https://github.com/example/cat-dog-classifier-pytorch

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

shayudiandian

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值