手撕Vision Transformer -- Day5 -- predict.py

手撕Vision Transformer – Day5 – predict.py

Vit 网络结构图

在这里插入图片描述

Vit 网络结构

predict代码

Part1 库函数

# 该模块主要是为了预测分类,输入一个图像得到一个类别
'''
# Part1 引入相关的模型
'''
import torch
from dataset import Mnist_dataset
import matplotlib.pyplot as plt

Part2 初始化模型的一些参数

'''
# part2 下载模型
'''
net = torch.load('VIT_eopch_0.pt')
net.eval()
data_cs = Mnist_dataset(is_tran=False)

Part3 开始训练

'''
# Part3 开始测试
'''
if __name__ == '__main__':
    img, label = data_cs[1]
    label_predict = net(img.unsqueeze(0))
    label_predict = torch.argmax(label_predict)
    if label_predict == label:
        print('真实的标签为{},预测的标签为{},预测正确'.format(label, label_predict))
    else:
        print('真实的标签为{},预测的标签为{},预测错误'.format(label, label_predict))
    # 开始绘制图像
    plt.imshow(img.permute(2,1,0))
    plt.show()

参考

视频讲解:【Sora重要技术】复现ViT(Vision Transformer)模型_哔哩哔哩_bilibili

原理参考:手撕Vision Transformer – Day1 – 基础原理-优快云博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值