pytorch rnn 实现手写字体识别

本文介绍如何使用PyTorch构建RNN进行手写数字识别,包括数据加载、模型构建、训练及测试过程。通过MNIST数据集,演示了RNN在序列数据上的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

构建 RNN 代码


import  torch
import   torch.nn  as  nn
from  torch.autograd  import  Variable

import  torch.utils.data  as  Data

import   torchvision

import   matplotlib.pyplot  as  plt


torch.manual_seed(1)

#batch size
BATCH_SIZE=50
#学习率
LR= 0.001
DOWNLOAD=False
#是否训练
TRAIN =False



class   RNN(nn.Module):
    def __init__(self):
        super(RNN,self).__init__()


        '''
        input_size:输入特征的数目
        hidden_size:隐层的特征数目
        num_layers:这个是模型集成的LSTM的个数 记住这里是模型中有多少个LSTM摞起来 一般默认就1个
        #batch_first: 输入数据的size为[batch_size, time_step, input_size]还是[time_step, batch_size, input_size]
       '''
        self.rnn= nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=3,
            batch_first=True #batch_first: 输入数据的size为[batch_size, time_step, input_size]还是[time_step, batch_size, input_size]

        )

        self.out = nn.Linear(64,10)

        self.optimizer = torch.optim.Adam(self.parameters(),lr=LR)
        self.lossFunc= nn.CrossEntropyLoss()


    def forward(self,x):

        #x [ batch,28,28]

        r_out ,(h_n,h_c)= self.rnn(x,None)

        #r_out [50,28,64]   h_n=[1,50,64]  h_c =[1,50,64]

        #r_out  表示 每一次输入  28 个像素  输入了  50* 28 次

        #h_n 表示    每 28*28 为一次 记录  隐藏层 为 64 所以为  50,64  每28*28为一个记录 参数

        print(r_out.size(), h_n.size(),h_c.size())

        r_out = self.out(r_out[:,-1,:])

        return  r_out


    def  lossFunction(self,predict ,batchY):
        loss = self.lossFunc(predict,batchY)
        self.optimizer.zero_grad()
        loss.backward()
        print("loss==",loss.data)
        self.optimizer.step()

加载数据

tranData =  torchvision.datasets.MNIST(
    root="d:/mnist/",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD

)

testData = torchvision.datasets.MNIST(
    root="d:/mnist/",
    train=False
)

trainLoader =  Data.DataLoader(dataset=tranData,
                               batch_size=BATCH_SIZE,
                               shuffle=True
                               )

# 为了节约时间, 我们测试时只测试前2000个
# shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_x = Variable(torch.unsqueeze(testData.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
test_y = testData.test_labels[:2000]

使用RNN 训练 和测试数据


#构造RNN
myRNN= RNN()


#训练数据
if(TRAIN):


    for  epoch in  range(3):
      for  step  ,(x,y) in enumerate(trainLoader):
            trainX = Variable(x.view(-1,28,28))
            print("trainX==",trainX.size())
            tranY = Variable(y)
            predict= myRNN(trainX)
            print("predict==",predict)
            myRNN.lossFunction(predict,tranY)

    torch.save(myRNN.state_dict(), "d:/mnist/rnn.pkl")
else:

    myRNN.load_state_dict(torch.load("d:/mnist/rnn.pkl"))




#测试数据
testOut = myRNN(test_x[:20].view(-1,28,28))

print("testOut==",testOut.size())
#预测值
testPredict = torch.max(testOut,1)[1]

print("testPredict==", testPredict.size())
print(testPredict,test_y[:20])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值