基于PyTorch和ResNet18的花卉识别实战(附完整代码)

一、项目背景与效果

花卉分类是计算机视觉的经典任务。本文使用PyTorch框架,基于ResNet18模型实现了102种花卉的分类任务。完整代码可直接复制运行,文中同步分析性能瓶颈与优化方案。

二、环境配置与数据准备

1. 环境要求

# 主要依赖库
import torch
from torch import nn, optim
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import numpy as np
import json

2. 数据集结构

data/
├── train/         # 训练集(6489张)
│   ├── 1/        # 类别编号文件夹
│   ├── 2/
│   └── ... 
├── valid/         # 验证集(1700张)
└── flower_names.json  # 类别映射文件

3. 类别映射文件

# flower_names.json 示例
{
  "1": "玫瑰",
  "2": "郁金香",
  ...,
  "102": "墙藓"
}

三、完整代码实现

1. 数据加载与增强

# 定义数据路径
data_dir = 'D:/python_text/python/花卉识别/data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

# 数据增强与预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(64),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# 创建数据集
batch_size = 512
train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])
valid_dataset = datasets.ImageFolder(valid_dir, transform=data_transforms['valid'])

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

2. 模型定义(修改ResNet18)

class ResNet18Model(nn.Module):
    def __init__(self, num_classes=102, pretrained=True):
        super().__init__()
        self.base_model = models.resnet18(pretrained=pretrained)
        num_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        return self.base_model(x)

# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNet18Model().to(device)

3. 训练配置

# 定义损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 需显式设置学习率
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# 加载类别名称
with open('D:/python_text/python/花卉识别/data/flower_names.json', 'r') as f:
    flower_names = json.load(f)

4. 训练与验证循环

num_epochs = 25
for epoch in range(num_epochs):
    # 训练阶段
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * images.size(0)
    
    # 验证阶段
    model.eval()
    valid_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            valid_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
    
    # 打印统计信息
    train_loss = train_loss / len(train_dataset)
    valid_loss = valid_loss / len(valid_dataset)
    valid_acc = correct / len(valid_dataset)
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f} | Val Acc: {valid_acc:.4f}')

四、优化方案

1. 代码改进建议

# 优化点示例:减小批次大小 + 学习率预热
batch_size = 256  # 原512
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)  # 添加L2正则

# 添加学习率预热
from torch.optim.lr_scheduler import LinearLR
warmup_scheduler = LinearLR(optimizer, start_factor=0.01, total_iters=5)

2. 其他优化方向

  • 数据层面

    • 使用公开数据集(如Oxford-102 Flowers)扩充数据

    • 添加针对性数据增强(花瓣局部裁剪、光照模拟)

  • 模型层面

    • 更换为ResNet50/ViT模型

    • 添加注意力机制模块

  • 训练技巧

    • 使用标签平滑(Label Smoothing)

    • 引入Focal Loss解决类别不平衡

五、总结

本文提供了可直接运行的花卉分类完整代码,并针对低准确率问题提出了改进方向。关键点:

  1. 数据增强需符合花卉特征(避免过度旋转)

  2. 合理设置超参数(批次大小、学习率)

  3. 复杂场景建议使用更先进的模型架构

代码可直接复制到本地,修改数据集路径后运行。建议尝试添加Grad-CAM可视化模块,深入分析模型决策依据。

### GNSS 中多普勒、伪距载波相位的概念及应用 #### 1. 多普勒效应及其在GNSS中的意义 在卫星导航系统中,载波多普勒指的是由于相对运动引起的接收到的载波频率的变化。这种变化能够反映用户相对于卫星的速度。具体而言,当用户设备与卫星之间存在相对运动时,接收端检测到的载波频率会发生偏移,该现象被称为载波多普勒效应[^1]。 #### 2. 伪距定义及其重要性 伪距是指从地面站到空间飞行器之间的几何路径长度加上各种误差成分的结果。它本质上是通过测量信号传输时间并乘以光速来估算的距离值。然而,实际操作过程中,这个数值包含了多种因素造成的偏差,比如大气延迟、钟差等。尽管如此,在没有其他更精准数据的情况下,伪距仍然是确定位置的关键参数之一。 #### 3. 载波相位测量的特点 相比于基于C/A码或P(Y)码测定的粗略距离——即所谓的“伪距”,载波相位提供了更为精细的位置信息。这是因为载波波长远小于扩频码周期,从而使得其对应的测距精度更高。不过,使用这种方法面临的主要挑战在于如何解决整周模糊度问题,也就是不知道确切有多少完整的波长存在于两地间。一旦解决了这个问题,就能显著提高定位准确性[^2]。 #### 4. 组合技术:相位平滑伪距的应用 为了克服单一方法存在的局限性,工程师们开发出了结合两者优点的技术方案—相位平滑伪距。这项技术充分利用了载波相位较高的分辨率以及伪距易于获取的优势,经过适当处理后可以获得更加可靠且准确的位置估计。特别是对于动态环境下的快速收敛平稳跟踪具有重要意义[^3]。 ```python def phase_smoothed_pseudorange(pseudo_range, carrier_phase): """ 计算相位平滑伪距 参数: pseudo_range (float): 初始伪距测量值 carrier_phase (float): 同步时间段内的累积载波相位变化 返回: float: 平滑后的伪距 """ smoothed_value = pseudo_range + carrier_phase / wavelength return smoothed_value ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值