第一步:准备数据
幼苗数据集,英文为 plant seedlings dataset。
https://www.kaggle.com/competitions/plant-seedlings-classification/data
train 目录下有 221 张 Common wheat 的图片。
共十二种植物的幼苗图片集。
- Black-grass
- Charlock
- Cleavers
- Common Chickweed
- Common wheat
- Fat Hen
- Loose Silky-bent
- Maize
- Scentless Mayweed
- Shepherds Purse
- Small-flowered Cranesbill
- Sugar beet
具体信息如下:
第二步:搭建模型
FasterNet是一种高效的神经网络架构,旨在提高计算速度而不牺牲准确性,特别是在视觉任务中。它通过一种称为部分卷积(PConv)的新技术来减少冗余计算和内存访问。这种方法使得FasterNet在多种设备上运行速度比其他网络快得多,同时在各种视觉任务中保持高准确率。例如,FasterNet在ImageNet-1k数据集上的表现超过了其他模型,如MobileViT-XXS,展现了其在速度和准确度方面的优势。
FasterNet的基本原理可以总结为以下几点:
1. 部分卷积(PConv): FasterNet引入了部分卷积(PConv),这是一种新型的卷积方法,它通过只处理输入通道的一部分来减少计算量和内存访问。
2. 加速神经网络: FasterNet利用PConv的优势,实现了在多种设备上比其他现有神经网络更快的运行速度,同时保持了较高的准确度。
下面为大家展示的是FasterNet的整体架构。
它包括四个层次化的阶段,每个阶段由一系列FasterNet块组成,并由嵌入或合并层开头。最后三层用于特征分类。在每个FasterNet块中,PConv层之后是两个点状卷积(PWConv)层。为了保持特征多样性并实现更低的延迟,仅在中间层之后放置了归一化和激活层。
第三步:训练代码
1)损失函数为:交叉熵损失函数
2)FasterNet代码:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from functools import partial
from typing import List
from torch import Tensor
import copy
import os
try:
from mmdet.models.builder import BACKBONES as det_BACKBONES
from mmdet.utils import get_root_logger
from mmcv.runner import _load_checkpoint
has_mmdet = True
except ImportError:
print("If for detection, please install mmdetection first")
has_mmdet = False
class Partial_conv3(nn.Module):
def __init__(self, dim, n_div, forward):
super().__init__()
self.dim_conv3 = dim // n_div
self.dim_untouched = dim - self.dim_conv3
self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
if forward == 'slicing