偷得浮生半日闲:手写数字的识别

 本次我们利用拟利用三层神经网络来识别手写数字,利用http://pjreddie.com/projects/mnist-in-csv/提供的训练数字和测试数字,来进行网络的训练。隐藏层设置为100个节点,学习率设置为0.3,并将全部数据集训练5次。最终得到正确率为

import numpy
import scipy.special
class network:
    def __init__(self,input_nodes,hidden_nodes,output_nodes,learn):
        self.inodes=input_nodes
        self.hnodes=hidden_nodes
        self.onodes=output_nodes
        self.lr=learn
        self.w_ih=(numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes)))
        self.w_ho=(numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes)))
    def sigmod_function(self,x):
        return scipy.special.expit(x)
    def query(self,input_list):
        input=numpy.array(input_list,ndmin=2).T
        hidden_input=numpy.dot(self.w_ih,input)
        hidden_out=self.sigmod_function(hidden_input)
        outre_input=numpy.dot(self.w_ho,hidden_out)
        outre_output=self.sigmod_function(outre_input)
        return outre_output
    def train(self,input_list,output_list):
        input=numpy.array(input_list,ndmin=2).T
        target=numpy.array(output_list,ndmin=2).T
        hidden_input=numpy.dot(self.w_ih,input)
        hidden_out=self.sigmod_function(hidden_input)
        outre_input=numpy.dot(self.w_ho,hidden_out)
        outre_output=self.sigmod_function(outre_input)
        output_errors=target-outre_output
        hidden_errors=numpy.dot(self.w_ho.T,output_errors)
        self.w_ho+=self.lr*numpy.dot(output_errors*outre_output*(1-outre_output),hidden_out.T)
        self.w_ih+=self.lr*numpy.dot(hidden_errors*hidden_out*(1-hidden_out),input.T)
input_nodes=784
output_nodes=10
hidden_nodes=100
learn=0.3
n=network(input_nodes,hidden_nodes,output_nodes,learn)
f=open("D:\\bpnetwork\\mnist_train.csv",'r')
train_data=f.readlines()    
f.close()
for record in train_data:
    all_values=record.split(',')
    input_list=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
    output_list=numpy.zeros(output_nodes)+0.01
    output_list[int(all_values[0])]=0.99
    n.train(input_list,output_list)
f=open("D:\\bpnetwork\\mnist_test.csv",'r')
test_data=f.readlines()
f.close()
result=[]
for record in test_data:
    data=record.split(',')
    test_per=(numpy.asfarray(data[1:])/255.0*0.99)+0.01
    label_result=numpy.argmax(n.query(test_per))
    label_correct=int(data[0])
    if label_result==label_correct:
        result.append('1')
    else:
        result.append('0')
length=len(result)
i=0
cnt=0

while i<length:
    if result[i]=='1':
        cnt+=1
    i+=1
print(cnt/length)

正确率如图所示:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值