Hands-On Image Classification with VisionMamba


In this section we build a hands-on VisionMamba image classifier using several techniques, starting with a bidirectional VisionMamba model that incorporates positional representations. During training we use a suitable loss function and optimizer to minimize the prediction error, optimizing the model parameters over many iterations. To guard against overfitting, we also apply regularization techniques such as Dropout and weight decay.
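The two regularization techniques mentioned here are both one-liners in PyTorch. As a minimal sketch (the tiny classifier head below is purely illustrative, not part of the VisionMamba model): Dropout is inserted as a layer, while weight decay is passed to the optimizer.

```python
import torch

# Hypothetical tiny classifier head, used only to illustrate the two
# regularization techniques: Dropout and weight decay.
model = torch.nn.Sequential(
    torch.nn.Linear(768, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.1),   # randomly zeroes 10% of activations during training
    torch.nn.Linear(256, 10),
)

# AdamW applies decoupled weight decay on every parameter update
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0.01)
```

Note that Dropout is only active in training mode; calling `model.eval()` disables it, so evaluation results are deterministic.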

After training, we evaluate the Mamba model by computing metrics such as classification accuracy and recall. We also use a confusion matrix to visualize the classification results, giving a more intuitive view of how the model performs on each class.
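Accuracy, per-class recall, and the confusion matrix can all be derived from one table of counts. A minimal sketch in plain PyTorch (the helper name `confusion_matrix` and the toy labels are illustrative only):

```python
import torch

def confusion_matrix(preds: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # rows = true class, columns = predicted class
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(labels, preds):
        cm[t, p] += 1
    return cm

labels = torch.tensor([0, 1, 1, 2])
preds = torch.tensor([0, 1, 2, 2])
cm = confusion_matrix(preds, labels, num_classes=3)

# accuracy = trace / total; per-class recall = diagonal / row sums
accuracy = cm.diag().sum().item() / cm.sum().item()
recall = cm.diag().float() / cm.sum(dim=1).float()
```

The diagonal holds the correct predictions, so off-diagonal entries show exactly which classes get confused with which.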

7.3.1  Building the VisionMamba Model

Concretely, we build VisionMamba by stacking several of the augmented Mamba blocks loaded earlier, and adding a positional-encoding module and a bidirectional computation path. The code is as follows:

import torch
import position     # provides VisionRotaryEmbeddingFast
import moudle       # provides MambaBlock (module name as defined earlier)

class VisionMamba(torch.nn.Module):
    def __init__(self, img_size=32, embed_dim=768, patch_size=4, num_layers=3, num_classes=10,
                 if_bidirectional=True, if_rope=True, device="cuda"):
        super().__init__()
        self.num_classes = num_classes
        # embed_dim = 768 is precomputed by hand from the input image size
        self.d_model = self.num_features = self.embed_dim = embed_dim  

        self.if_rope = if_rope
        self.if_bidirectional = if_bidirectional

        self.patch_embedding_layer = PatchEmbed(img_size=img_size).to(device)
        # number of patches per image; use integer division
        grid_size = (img_size // patch_size) * (img_size // patch_size)

        # learnable positional embedding, one vector per patch
        self.pos_embed = torch.nn.Parameter(torch.zeros(grid_size, embed_dim, device=device))


        if if_rope:
            half_head_dim = embed_dim // 2
            hw_seq_len = img_size // patch_size
            self.rope_layer = position.VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=32,
                ft_seq_len=hw_seq_len
            ).to(device)

        self.head = torch.nn.Linear(embed_dim, num_classes, device=device)

        # ModuleList (rather than a plain list) so the blocks register their parameters
        self.mamba_blocks = torch.nn.ModuleList(
            [moudle.MambaBlock(d_model=embed_dim, state_size=32, device=device) for _ in range(num_layers)])
        self.norm_f = torch.nn.LayerNorm(embed_dim,device=device)

    def forward(self,x):
        # add the positional embedding once, right after patch embedding
        x = self.patch_embedding_layer(x) + self.pos_embed

        B, M, _ = x.shape
        hidden_states = x
        if self.if_bidirectional:
            # consume two blocks (one forward, one backward) per loop iteration
            for i in range(len(self.mamba_blocks) // 2):
                if self.if_rope:
                    hidden_states = self.rope_layer(hidden_states)
                # forward-direction pass
                hidden_states_forward = self.mamba_blocks[i * 2](hidden_states)
                # backward-direction pass over the sequence-reversed input
                hidden_states_backward = self.mamba_blocks[i * 2 + 1](hidden_states.flip([1]))
                hidden_states = hidden_states_forward + hidden_states_backward.flip([1])
            # with an odd number of layers, the final block runs unidirectionally
            hidden_states = self.mamba_blocks[-1](hidden_states)
        else:
            for block in self.mamba_blocks:
                if self.if_rope:
                    hidden_states = self.rope_layer(hidden_states)
                hidden_states = block(hidden_states)

        hidden_states = self.norm_f(hidden_states)

        hidden_states = hidden_states.mean(dim=1)
        logits = self.head(hidden_states)

        return logits


if __name__ == '__main__':
    device = "cuda"
    image = torch.randn(size=(2,3,32,32)).to(device)
    output = VisionMamba(device=device)(image)
    print(output.shape)
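The PatchEmbed layer used above is defined elsewhere in the book. For reference, here is a minimal sketch consistent with the defaults used in this section (img_size=32, patch_size=4, embed_dim=768; the 3-channel input and the `proj` attribute name are assumptions about its interface):

```python
import torch

class PatchEmbed(torch.nn.Module):
    """Splits an image into non-overlapping patches and projects each to embed_dim."""
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # a strided convolution performs both the patch split and the linear projection
        self.proj = torch.nn.Conv2d(in_chans, embed_dim,
                                    kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                      # (B, embed_dim, H/ps, W/ps)
        return x.flatten(2).transpose(1, 2)   # (B, num_patches, embed_dim)

x = torch.randn(2, 3, 32, 32)
out = PatchEmbed()(x)
```

With a 32×32 input and 4×4 patches this produces 64 tokens per image, matching the `grid_size` used for the positional embedding above.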

7.3.2  Hands-On Image Classification with VisionMamba

Finally, we complete image classification with VisionMamba, visualizing the training process as it runs. The code is as follows:

import torch
import get_cifar10
import vision_mamba
from tqdm import tqdm

device = "cuda"
model = vision_mamba.VisionMamba().to(device)

# define the optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1200, eta_min=2e-6, last_epoch=-1)
loss_func = torch.nn.CrossEntropyLoss()

# set the training parameters
batch_size = 224

from torch.utils.data.dataloader import DataLoader
customer_dataset = get_cifar10.CustomerDataset()
dataloader = DataLoader(customer_dataset, batch_size=batch_size, shuffle=True)

for epoch in range(36):
    pbar = tqdm(dataloader, total=len(dataloader))  # progress bar for visualizing training
    for (batch_input_images, batch_labels) in pbar:  # iterate over batches from the dataloader
        optimizer.zero_grad()
        batch_input_images = batch_input_images.float().to(device)
        batch_labels = batch_labels.to(device)

        # forward pass and loss computation
        logits = model(batch_input_images)

        loss = loss_func(logits, batch_labels)
        # backward pass and optimization step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # update the learning rate

        # divide by the actual batch size, since the last batch may be smaller
        accuracy = (logits.argmax(1) == batch_labels).float().mean()
        # update the progress bar with the current epoch, loss, accuracy, and learning rate
        pbar.set_description(
            f"epoch:{epoch + 1}, train_loss:{loss.item():.5f}, train_accuracy:{accuracy.item():.2f}, lr:{lr_scheduler.get_last_lr()[0]:.7f}")

    torch.save(model.state_dict(), "./saver/modelpara.pt")

Note that the batch size that can be fed to the model in a single step depends on the available hardware, so it should be adjusted to your setup. Readers are encouraged to complete the remaining training and prediction steps on their own.
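As a starting point for the prediction step left to the reader, here is a minimal evaluation sketch. The dummy linear model and random tensors below are placeholders only; in practice you would pass the trained VisionMamba and a real test dataloader:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

@torch.no_grad()
def evaluate(model, dataloader, device="cpu"):
    """Returns the overall accuracy of `model` on `dataloader`."""
    model.eval()  # disable Dropout etc. for deterministic predictions
    correct, total = 0, 0
    for images, labels in dataloader:
        logits = model(images.float().to(device))
        correct += (logits.argmax(1) == labels.to(device)).sum().item()
        total += labels.size(0)
    return correct / total

# placeholder model and data; with the real model, first restore the weights:
#   model.load_state_dict(torch.load("./saver/modelpara.pt"))
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
data = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
acc = evaluate(model, DataLoader(data, batch_size=4))
```

Wrapping the loop in `torch.no_grad()` skips gradient tracking, which saves memory and speeds up inference.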
