# Yuanjie's Python neural-network study notes (Part 1)
#!/usr/bin/env python
import numpy
import scipy.special
class neuralNetwork:
    """A simple three-layer (input / hidden / output) feed-forward network.

    Weights are trained by stochastic gradient descent, one sample at a
    time, with the logistic sigmoid as the activation function of every
    layer.
    """

    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        """Initialise the topology, learning rate and random weights.

        Args:
            inputnodes: number of input-layer nodes.
            hiddennodes: number of hidden-layer nodes.
            outputnodes: number of output-layer nodes.
            learningrate: step size applied to the weight updates.
        """
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes
        self.lr = learningrate
        # Initial weights drawn from a normal distribution centred on 0.0
        # with standard deviation 1/sqrt(number of incoming links), which
        # keeps early activations inside the sigmoid's sensitive region.
        self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5),
                                       (self.hnodes, self.inodes))
        self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5),
                                       (self.onodes, self.hnodes))
        # Sigmoid activation function.
        self.activation_function = lambda x: scipy.special.expit(x)

    def train(self, inputs_lists, targets_list):
        """Run one forward pass and one backpropagation step for a sample.

        Args:
            inputs_lists: iterable of input values for one sample.
            targets_list: iterable of target output values for the sample.
        """
        # Convert the input and target lists into column vectors.
        inputs = numpy.array(inputs_lists, ndmin=2).T
        targets = numpy.array(targets_list, ndmin=2).T
        # Forward pass: hidden layer, then output layer.
        hidden_inputs = numpy.dot(self.wih, inputs)
        hidden_outputs = self.activation_function(hidden_inputs)
        final_inputs = numpy.dot(self.who, hidden_outputs)
        final_outputs = self.activation_function(final_inputs)
        # Output-layer error, then that error apportioned back to the
        # hidden layer in proportion to the connecting weights.
        # (The original also computed the never-used input-layer errors;
        # that dead computation has been removed.)
        output_errors = targets - final_outputs
        hidden_errors = numpy.dot(self.who.T, output_errors)
        # Gradient-descent updates: error * sigmoid'(x) = e * out * (1 - out),
        # outer product with the upstream layer's outputs.
        self.who += self.lr * numpy.dot(
            output_errors * final_outputs * (1.0 - final_outputs),
            numpy.transpose(hidden_outputs))
        self.wih += self.lr * numpy.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs),
            numpy.transpose(inputs))

    def query(self, inputs_lists):
        """Forward-propagate *inputs_lists* and return the output layer.

        Args:
            inputs_lists: iterable of input values for one sample.

        Returns:
            numpy array of shape (onodes, 1) with the output activations.
        """
        # Column-vector form of the input.
        inputs = numpy.array(inputs_lists, ndmin=2).T
        # Hidden layer.
        hidden_inputs = numpy.dot(self.wih, inputs)
        hidden_outputs = self.activation_function(hidden_inputs)
        # Output layer.
        final_inputs = numpy.dot(self.who, hidden_outputs)
        final_outputs = self.activation_function(final_inputs)
        return final_outputs
# Hyper-parameters for a small 3-3-3 demonstration network.
input_nodes = 3
hidden_nodes = 3
output_nodes = 3
learning_rate = 0.3

# Build the network instance used by the rest of the script.
n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)