手撕Vision Transformer -- Day2 -- Dataset

手撕Vision Transformer – Day2 – Dataset

Vit 网络结构图

在这里插入图片描述

Vit 网络结构

Dataset代码

Part1 库函数

# 该模块主要是为了嵌入获取相关的数据,得到dataloader等等,只不过把数据集的获取设计成了一个类
import torch
import torchvision
from torchvision import transforms
from torch import nn
# 总是容易忘记这个data的引用
import torch.utils.data as data

Part2 初始化一个数据集的类

class Mnist_dataset(data.Dataset):
    def __init__(self, is_tran):
        super().__init__()
        self.transform_action = transforms.Compose([
            transforms.ToTensor()
        ])
        self.Mnist = torchvision.datasets.MNIST(root='./Mnist_data', train=is_tran, transform=self.transform_action,
                                                download=True)

    def __len__(self):
        return len(self.Mnist)

    def __getitem__(self, index):
        return self.Mnist[index][0], self.Mnist[index][1]  # 返回对应的图像和label

Part3 测试

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    plt.figure(figsize=(45, 45))
    mnist = Mnist_dataset(is_tran=True)
    img, label = mnist[0]
    # 因为Totensor,会把数据进行归一化,并且把pillow(imag_size,imag_size,channel)转化为(channel,imag_size,imag_size)
    # 为了能够画图,需要把channel换回来
    plt.imshow(img.permute(2, 1, 0))
    plt.title(f"Label: {label}")
    plt.show()

参考

视频讲解:【Sora重要技术】复现ViT(Vision Transformer)模型_哔哩哔哩_bilibili

原理参考:手撕Vision Transformer – Day1 – 基础原理-优快云博客

github资料:YanxinTong/VIT_Pytorch: 利用Pytorch 手撕 VIT 模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值