import numpy as np
from functools import reduce
class Perceptron(object):
    """Single neuron computing ``activator(sum(x_i * w_i) + bias)``.

    The activation function is injected by the caller, so the same class can
    act as a threshold perceptron (step activator) or as a linear unit
    (identity activator). Training applies, after every sample,
    ``w_i += rate * (label - output) * x_i`` and ``bias += rate * (label - output)``.
    """

    def __init__(self, input_num, activator):
        """Create a model with ``input_num`` weights and a bias, all 0.0.

        Args:
            input_num: number of weights (one per input feature).
            activator: callable applied to the scalar net input in ``predict``.
        """
        self.activator = activator
        self.weights = [0.0] * input_num
        self.bias = 0.0

    def __str__(self):
        return 'weights\t:%s\nbias\t:%f\n' % (self.weights, self.bias)

    def predict(self, input_vec):
        """Return ``activator(sum(x_i * w_i) + bias)`` for one input vector.

        Inputs and weights are paired with ``zip``, so surplus elements on
        either side are ignored rather than reported.
        """
        # Start from 0.0 so the sum is accumulated left-to-right as floats.
        net = sum((x * w for x, w in zip(input_vec, self.weights)), 0.0)
        return self.activator(net + self.bias)

    def train(self, input_vecs, labels, iteration, rate):
        """Run ``iteration`` full passes over the training data.

        Args:
            input_vecs: sequence of input vectors.
            labels: target value for each vector, aligned by position.
            iteration: number of passes over the data.
            rate: learning rate scaling every update.
        """
        for _ in range(iteration):
            self._one_iteration(input_vecs, labels, rate)

    def _one_iteration(self, input_vecs, labels, rate):
        """Make one pass over the data, updating parameters after each sample."""
        # Each sample is an (input_vec, label) pair; zip stops at the shorter input.
        for input_vec, label in zip(input_vecs, labels):
            output = self.predict(input_vec)
            self._update_weights(input_vec, output, label, rate)

    def _update_weights(self, input_vec, output, label, rate):
        """Move weights and bias toward ``label`` by ``rate * (label - output)``."""
        delta = label - output
        self.weights = [w + rate * delta * x for x, w in zip(input_vec, self.weights)]
        self.bias += rate * delta
def f(x):
    """Identity activator: return the net input ``x`` unchanged.

    With this activator a Perceptron outputs its raw weighted sum, i.e. it
    behaves as a linear unit (as opposed to a step activator such as
    ``1 if x > 0 else 0``).
    """
    return x
class LinearUnit(Perceptron):
    """Perceptron that uses the module-level activator ``f``.

    With ``f`` as the identity, the prediction is just ``w . x + bias``. The
    inherited update rule is then the gradient step for the squared error
    ``0.5 * (label - output) ** 2``.
    """

    def __init__(self, input_num):
        super().__init__(input_num, f)
def get_training_dataset():
    """Return the toy dataset as ``(input_vecs, labels)``.

    Each input vector holds a single feature (years worked); each label is
    the matching monthly salary. Fresh lists are built on every call.
    """
    years_and_salaries = (
        (5, 5500),
        (3, 2300),
        (8, 7600),
        (1.4, 1800),
        (10.1, 11400),
    )
    input_vecs = [[years] for years, _ in years_and_salaries]
    labels = [salary for _, salary in years_and_salaries]
    return input_vecs, labels
def train_linear_unit(iteration=10, rate=0.01):
    """Train a one-input LinearUnit on the salary dataset and return it.

    Args:
        iteration: number of passes over the training data (default 10).
        rate: learning rate for each per-sample update (default 0.01).

    Returns:
        The trained LinearUnit.
    """
    linear_unit = LinearUnit(1)
    input_vecs, labels = get_training_dataset()
    linear_unit.train(input_vecs, labels, iteration, rate)
    return linear_unit
def plot(linear_unit):
    """Scatter the training data and overlay the fitted line ``y = w * x + b``.

    Only ``linear_unit.weights[0]`` is used, so this expects a one-input
    unit. Blocks on ``plt.show()`` until the figure window is closed
    (behavior depends on the matplotlib backend).
    """
    # Imported here so the rest of the module does not require matplotlib.
    import matplotlib.pyplot as plt

    input_vecs, labels = get_training_dataset()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter([vec[0] for vec in input_vecs], labels)

    weight = linear_unit.weights[0]
    bias = linear_unit.bias
    # Integer x positions 0..11 cover the training inputs (1.4 to 10.1).
    xs = list(range(12))
    ys = [weight * x + bias for x in xs]
    ax.plot(xs, ys)
    plt.show()
if __name__ == '__main__':
    linear_unit = train_linear_unit()
    print(linear_unit)
    # Each entry pairs an output template with the single input fed to predict().
    salary_queries = (
        ('Work 3.4 years, monthly salary = %.2f', 3.4),
        ('Work 15 years, monthly salary = %.2f', 15),
        ('Work 1.5 years, monthly salary = %.2f', 1.5),
        ('Work 6.3 years, monthly salary = %.2f', 6.3),
    )
    for template, years in salary_queries:
        print(template % linear_unit.predict([years]))
    plot(linear_unit)
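# 线性单元和梯度下降--代码实现
# (Linear unit and gradient descent: code implementation)
# 最新推荐文章于 2022-06-16 17:38:44 发布
# (Web-page residue: "latest recommended article published 2022-06-16 17:38:44")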