基于LSTM实现mnist手写数字识别

本文介绍了如何利用LSTM模型对mnist手写数字数据集进行图像分类。通过读取28*28像素的图像,设置LSTM网络结构,其中输入到隐藏层有28*128个参数,隐藏层到输出层有128*10个参数。LSTM结构被设计为每个图像行作为一次输入,训练过程中循环28次。最终采用特定的优化算法和损失函数进行模型训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

首先读取数据,数据源是mnist库,可以通过input_data中read_data函数直接读取数据,数据图像为28*28。

#导入库
import tensorflow as tf
#下载数据对应的库
import input_data
import numpy as np
import matplotlib.pyplot as plt
print ("Packages imported")

#导入mnist数据
mnist = input_data.read_data_sets("data/", one_hot=True)#one_hot=True 表示 数据的标签是one_hot编码的,即数据标签为1*10的数组
#读取训练数据,训练标签,测试数据,测试标签
trainimgs, trainlabels, testimgs, testlabels \
 = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels 
#获取训练数据个数,测试集数据个数,图像维度和类别数
ntrain, ntest, dim, nclasses \
 = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")

读取数据后设置参数。本次使用LSTM作为训练模型,因此需要搭建LSTM,因图像为28*28,所以将每一行图像作为一次输入,这样每一次训练,LSTM需要运算28次,设置隐层为128,所以从输入到隐层的全连接参数为28*128个,经过运算后输出与隐层全连接参数为128*10,每运行一次的输出为1*10.

LSTM结构如图:

搭建了28个lstm,每一个lstm公用参数,因此也可以看做搭建了一个lstm循环了28次,只有最终的结果有作用。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值