读取原始数据
import pandas as pd
import numpy as np
in_data = pd.read_table('./origin-data/perceptron_15.dat', sep='\s+', header=None)
X_train = np.array(in_data.loc[:,[0,1,2,3]])
y_train = np.array(in_data[4])
训练感知机模型
class MyPerceptron:
def __init__(self):
self.w = None
self.b = 0
self.l_rate = 1
def fit(self, X_train, y_train):
#用样本点的特征数更新初始w,如x1=(3,3)T,有两个特征,则self.w=[0,0]
self.w = np.zeros(X_train.shape[1])
i = 0
while i < X_train.shape[0]:
X = X_train[i]
y = y_train[i]
# 如果y*(wx+b)≤0 说明是误判点,更新w,b
if y * (np.dot(self.w, X) + self.b) <= 0:
self.w += self.l_rate * np.dot(y, X)
self.b += self.l_rate * y
i=0 #如果是误判点,从头进行检测
else:
i+=1