手撕Vision Transformer – Day5 – predict.py
Vit 网络结构图

Vit 网络结构
predict代码
Part1 库函数
'''
# Part1 引入相关的模型
'''
import torch
from dataset import Mnist_dataset
import matplotlib.pyplot as plt
Part2 初始化模型的一些参数
'''
# part2 下载模型
'''
net = torch.load('VIT_eopch_0.pt')
net.eval()
data_cs = Mnist_dataset(is_tran=False)
Part3 开始训练
'''
# Part3 开始测试
'''
if __name__ == '__main__':
img, label = data_cs[1]
label_predict = net(img.unsqueeze(0))
label_predict = torch.argmax(label_predict)
if label_predict == label:
print('真实的标签为{},预测的标签为{},预测正确'.format(label, label_predict))
else:
print('真实的标签为{},预测的标签为{},预测错误'.format(label, label_predict))
plt.imshow(img.permute(2,1,0))
plt.show()
参考
视频讲解:【Sora重要技术】复现ViT(Vision Transformer)模型_哔哩哔哩_bilibili
原理参考:手撕Vision Transformer – Day1 – 基础原理-优快云博客