前言
首先需要将图片数据转化为一个序列数据,MNIST手写数字的图片大小是28×28,那么可以将每张图片看作是长为28的序列,序列中的每个元素的特征维度是28,这样就将图片变成了一个序列。同时考虑到循环神经网络的记忆性,所以图片从左往右输入网络的时候,网络可以记忆住前面观察到的东西,也就是说一张图片虽然被切割成了28份,但是网络能够通过记住前面的部分,同时和后面的部分结合 得到最后预测数字的输出结果,所以从理论上而言是行得通的。
实验代码
实验代码如下:
import torch
import torch.nn as nn
from torchvision import transforms as transform
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from datetime import datetime
from torch.autograd import Variable
data_tf = transform.Compose([
transform.ToTensor(),
transform.Normalize([0.5], [0.5])
])
train_dataset = MNIST('./data', train=True, download=True, transform=data_tf)
test_dataset = MNIST('./data', train=False, transform=data_tf)
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)
class rnn_classification(nn.Module):
def __init__(self, dim_in, dim_hidden, layer_num, n_class):
super