理论知识见https://www.zybuluo.com/hanbingtao/note/448086
直接上python3的代码
#coding utf-8
import matplotlib.pyplot as plt
from functools import reduce
class perceptron(object):
#初始化,输入训练数目,激活函数
def __init__(self,input_num,activator):#activator为激活函数
self.activator=activator
self.weights=[0.0 for _ in range(input_num)]#权重初始化为0
self.bias=0.0#偏置初始化为0.0
#运算
def operation(self,input_vec):
#对激活函数中的参数做运算,x[0]代表input_vec,x[1]代表weights
return self.activator(reduce(lambda a,b:a+b,map(lambda x:x[0]*x[1],zip(input_vec,self.weights)),0.0)+self.bias)#0.0为reduce的初始计算值
#权值更新
def updata(self,input_vec,output,label,rate):
delta=label-output
self.weights=list(map(lambda x:x[1]+rate*delta*x[0],zip(input_vec,self.weights)))#加上list跟python2有区别
self.bias+=rate*delta
#训练,输入数据及对应标签,迭代次数,学习率
def train(self,input_vecs,labels,iteration_num,rate):
for i in range(iteration_num):#iteration_num次迭代
samples=zip(input_vecs,labels)#打包
for (input_vec,label) in samples:
output=self.operation(input_vec)#计算输出值
self.updata(input_vec,output,label,rate)#更新
#预测
def predict(self,input_vec):
return self.operation(input_vec)
#打印权值,偏置
def __str__(self):#内部函数
return "weight: %s, bias: %f"%(self.weights,self.bias)#权值返回用%s
'''实现线性单元'''
#激活函数为线性函数
andActivator=lambda x:x
#得到训练数据
def getTrainData():
input_vecs=[[5],[3],[8],[1.4],[10.1]]#可重用多次循环迭代
labels=[5500,2300,7600,1800,11400]
return input_vecs,labels
#训练感知机
def trainPerceptron():
p=perceptron(1,andActivator)
input_vecs,labels=getTrainData()
p.train(input_vecs,labels,30,0.1)#100为迭代次数,0.1为学习率
return p
#画图
def plot(linearUnit):
input_vecs,labels=getTrainData()
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(input_vecs,labels)#横坐标input_vecs,纵坐标labels
weights=linearUnit.weights
bias=linearUnit.bias
x=range(0,15,1)#画0到15年的图像
y=list(map(lambda x:weights[0]*x+bias,x))
ax.plot(x,y)
plt.show()
#主函数
if __name__=='__main__':
train_perceptron=trainPerceptron()
print(train_perceptron)
#测试
print('工作3.4年,月薪%f'%train_perceptron.predict([3.4]))
print('工作15年,月薪%f'%train_perceptron.predict([15]))
print('工作1.5年,月薪%f'%train_perceptron.predict([1.5]))
print('工作6.3年,月薪%f'%train_perceptron.predict([6.3]))
plot(train_perceptron)
拟合得不好,见谅