本次我们拟利用三层神经网络来识别手写数字，使用 http://pjreddie.com/projects/mnist-in-csv/ 提供的训练数据和测试数据对网络进行训练和测试。隐藏层设置为100个节点，学习率设置为0.3，并将全部训练数据集训练5次（即5个epoch）。最终得到的正确率见文末。
import numpy
import scipy.special
class network:
    """Three-layer (input -> hidden -> output) fully connected neural network.

    Both layers use the logistic sigmoid activation. Training is done one
    sample at a time with a backpropagation-style weight update.

    Args:
        input_nodes: number of input nodes.
        hidden_nodes: number of hidden nodes.
        output_nodes: number of output nodes.
        learn: learning rate applied to every weight update.
    """

    def __init__(self, input_nodes, hidden_nodes, output_nodes, learn):
        self.inodes = input_nodes
        self.hnodes = hidden_nodes
        self.onodes = output_nodes
        self.lr = learn
        # Draw each weight from N(0, 1/sqrt(fan_in)), where fan_in is the
        # number of links coming *into* a node of the receiving layer.
        # The previous version scaled by the size of the receiving layer
        # instead, which over-scaled the hidden->output weights (std ~0.32
        # instead of 0.1 for a 100 -> 10 layer).
        # w_ih has shape (hnodes, inodes); w_ho has shape (onodes, hnodes).
        self.w_ih = numpy.random.normal(
            0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.w_ho = numpy.random.normal(
            0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

    def sigmod_function(self, x):
        """Apply the logistic sigmoid 1 / (1 + exp(-x)) element-wise.

        The misspelled name ("sigmod") is kept so existing callers still work.
        """
        return scipy.special.expit(x)

    def _forward(self, inputs):
        """Run one forward pass on a column vector; return (hidden_out, final_out)."""
        hidden_out = self.sigmod_function(numpy.dot(self.w_ih, inputs))
        final_out = self.sigmod_function(numpy.dot(self.w_ho, hidden_out))
        return hidden_out, final_out

    def query(self, input_list):
        """Return the network output for a single sample.

        Args:
            input_list: sequence of ``inodes`` numbers.

        Returns:
            numpy array of shape (onodes, 1) with values in (0, 1).
        """
        # ndmin=2 gives shape (1, inodes); transpose to a column vector.
        inputs = numpy.array(input_list, ndmin=2).T
        _, final_out = self._forward(inputs)
        return final_out

    def train(self, input_list, output_list):
        """Update the weights in place using one (input, target) sample.

        Args:
            input_list: sequence of ``inodes`` numbers.
            output_list: sequence of ``onodes`` target values.
        """
        inputs = numpy.array(input_list, ndmin=2).T
        targets = numpy.array(output_list, ndmin=2).T
        hidden_out, final_out = self._forward(inputs)
        output_errors = targets - final_out
        # Propagate the errors back through the *current* hidden->output
        # weights, i.e. before those weights are updated below.
        # NOTE: the raw output errors (not the sigmoid-scaled deltas) are
        # propagated; this simplification of exact backprop is kept as-is.
        hidden_errors = numpy.dot(self.w_ho.T, output_errors)
        self.w_ho += self.lr * numpy.dot(
            output_errors * final_out * (1.0 - final_out), hidden_out.T)
        self.w_ih += self.lr * numpy.dot(
            hidden_errors * hidden_out * (1.0 - hidden_out), inputs.T)
# Hyper-parameters: 784 inputs (one per pixel value in a CSV row) and one
# output node per digit class 0-9.
input_nodes = 784
output_nodes = 10
hidden_nodes = 100
learn = 0.3
# Number of full passes over the training set.
epochs = 5

n = network(input_nodes, hidden_nodes, output_nodes, learn)

# Each row is expected to be "label,pixel_1,...,pixel_784" with pixel values
# in 0..255.
with open("D:\\bpnetwork\\mnist_train.csv", 'r') as f:
    train_data = f.readlines()

for _ in range(epochs):
    for record in train_data:
        line = record.strip()
        if not line:
            continue  # skip blank lines, e.g. a trailing newline at EOF
        all_values = line.split(',')
        # Scale pixels from 0..255 into 0.01..1.00. A zero input would make
        # that input's weight updates vanish (train multiplies by inputs.T).
        # numpy.asfarray was removed in NumPy 2.0; asarray(dtype=float) is
        # the replacement.
        input_list = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
        # Target vector: 0.01 everywhere and 0.99 at the true label, because
        # the sigmoid output can never reach exactly 0 or 1.
        output_list = numpy.zeros(output_nodes) + 0.01
        output_list[int(all_values[0])] = 0.99
        n.train(input_list, output_list)
# Evaluate on the test set: the predicted digit is the index of the output
# node with the largest activation.
with open("D:\\bpnetwork\\mnist_test.csv", 'r') as f:
    test_data = f.readlines()

cnt = 0     # correctly classified samples
length = 0  # samples evaluated
for record in test_data:
    line = record.strip()
    if not line:
        continue  # skip blank lines, e.g. a trailing newline at EOF
    data = line.split(',')
    # Same 0..255 -> 0.01..1.00 scaling as used for training.
    # numpy.asfarray was removed in NumPy 2.0; asarray(dtype=float) replaces it.
    test_per = (numpy.asarray(data[1:], dtype=float) / 255.0 * 0.99) + 0.01
    label_result = numpy.argmax(n.query(test_per))
    label_correct = int(data[0])
    length += 1
    if label_result == label_correct:
        cnt += 1

# Accuracy as a fraction in [0, 1]; guard against an empty test file.
if length:
    print(cnt / length)
else:
    print("no test records found")
测试集上的正确率如下图所示：