Using a Transformer encoder for MNIST digit recognition
1. Data preprocessing
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Convert images to tensors and normalize with the MNIST mean/std (0.1307, 0.3081)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download MNIST and wrap the train/test splits in DataLoaders
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
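As a quick, optional sanity check (not part of the original code), one batch can be pulled from train_loader to confirm the shapes the model below relies on: 16 grayscale images of size 1×28×28 and their 16 labels.

# Illustrative sanity check: inspect one training batch
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([16, 1, 28, 28])
print(labels.shape)  # expected: torch.Size([16])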
2. Define the Transformer model
class TransformerModel(nn.Module):
    def __init__(self, input_dim, num_classes, n_heads=4, num_encoder_layers=3):
        super(TransformerModel, self).__init__()
        # Learnable positional encoding: one 28-dim vector for each of the 28 image rows
        self.positional_encoding = nn.Parameter(torch.zeros(1, 28, 28))
        # batch_first=True keeps inputs as (batch, seq, feature), matching the (1, 28, 28) positional encoding
        encoder_layer = nn.TransformerEncoderLayer(d_model=28, nhead=n_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
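        # --- Assumed completion below, not taken from the original: the head name
        # --- (self.fc), the mean pooling, and the forward layout are illustrative,
        # --- treating each of the 28 image rows as a 28-dimensional token
        # --- (consistent with d_model=28 and the (1, 28, 28) positional encoding).
        self.fc = nn.Linear(28, num_classes)  # maps the pooled 28-dim representation to class logits

    def forward(self, x):
        # x: (batch, 1, 28, 28) MNIST batch; drop the channel dim so each of the
        # 28 rows becomes one 28-dimensional token for the encoder
        x = x.squeeze(1) + self.positional_encoding   # (batch, 28, 28)
        x = self.transformer_encoder(x)               # (batch, 28, 28)
        x = x.mean(dim=1)                             # mean-pool over the 28 row-tokens
        return self.fc(x)                             # (batch, num_classes)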