感知机算法及其实现

本文深入解析感知机算法,一种用于二分类的线性模型。详细介绍了感知机模型的工作原理,包括其几何解释、数据线性可分性、损失函数定义以及学习算法流程。并提供了Python实现示例,帮助读者理解感知机如何通过训练数据找到最优分类超平面。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

感知机算法及其实现

感知机(perceptron)是二分类的线性分类模型。输入实例的特征向量,输出实例的类别(+1和-1)。感知机属于判别模型,目的是求出将训练数据线性划分的分离超平面。感知机模型是一种线性分类模型,是支持向量机(SVM)的基础。

1. 感知机模型

假设输入空间(特征空间)是χ⊆Rn\chi \subseteq R^nχRn, 输出空间是y=+1,−1,y={+1,-1,}y=+1,1,. 输入表示实例的特征向量,输出表示实例的类别,由输入空间到输出空间的函数为:

f(x)=sign(W∗x+b)(1) f(x)=sign(W*x+b)\tag{1} f(x)=sign(Wx+b)(1)

通过训练数据算出WWWbbb,即权重(weight)和偏置(bias),使得学习到的感知机模型能对新的输入实例输出分类的类别。

1.1 几何解释

求解感知机模型的过程即找到特征空间的分离超平面(separating hyperplane)SSS,将输入实例分开,其中WWW是超平面的法向量,bbb是超平面的截距。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Rt9Yqg52-1647946576109)(_v_images/1539441161_26346.png)]
在这里插入图片描述

2. 数据的线性可分性

给定数据集

T={(x1,y1),(x2,y2),...,(xn,yn)},(2) T=\{(x_1,y_1),(x_2,y_2),...,(x_n,y_n)\},\tag{2} T={(x1,y1),(x2,y2),...,(xn,yn)},(2)
若存在超平面SSS

W∗x+b=0(3) W*x+b=0\tag{3} Wx+b=0(3)
能够将数据集所有的正实例和负实例完全正确的划分到超平面的两侧,即对所有数据

yi(Wxi+b)>0(4) y_{i}(Wx_{i}+b)>0\tag{4} yi(Wxi+b)>0(4)
则称数据集TTT是线性可分数据集(linearly separable data set);否则称数据集是线性不可分。
(PS. 若式4取等号,即代表有实例点在超平面上)
举个栗子,异或函数(XOR)产生的数据集是线性不可分的。(画下图就可知了)

3. 损失函数

输入空间中任一个点x0x_{0}x0到超平面SSS的距离为:

1∣∣W∣∣∣W∗x0+b∣(5) \frac{1}{||W||}|W*x_{0}+b|\tag{5} W1Wx0+b(5)
对于误分类点:

−yi(Wxi+b)<0(5) -y_{i}(Wx_{i}+b)<0\tag{5} yi(Wxi+b)<0(5)
所以误分类点xix_{i}xi到超平面的距离是:

1∣∣W∣∣yi(W∗xi+b)(6) \frac{1}{||W||}y_{i}(W*x_{i}+b)\tag{6} W1yi(Wxi+b)(6)
那么所有的误分类点到超平面SSS的总距离为:

−1∣∣W∣∣∑xi∈Myi(W∗xi+b)(7) -\frac{1}{||W||}\sum_{x_{i} \in M}y_{i}(W*x_{i}+b)\tag{7} W1xiMyi(Wxi+b)(7)
损失函数定义为:

L(W,b)=−∑xi∈Myi(W∗xi+b)(8) L(W,b)=-\sum_{x_{i} \in M}y_{i}(W*x_{i}+b)\tag{8} L(W,b)=xiMyi(Wxi+b)(8)
损失函数是非负的,当没有误分类点时,损失函数最小为0. 损失函数L(W,b)L(W,b)L(W,b)WWW,bbb的连续可导函数。
显然,感知机学习算法是误分类驱动的。
损失函数L(W,b)L(W,b)L(W,b)的梯度由:

∇WL(W,b)=−∑xi∈Myixi(9) \nabla_WL(W,b)=-\sum_{x_{i} \in M}y_ix_i\tag{9} WL(W,b)=xiMyixi(9)
∇bL(W,b)=−∑xi∈Myi(10) \nabla_bL(W,b)=-\sum_{x_{i} \in M}y_i\tag{10} bL(W,b)=xiMyi(10)
给出,SGD优化:随机选取一个误分类点(xi,yi)(x_i,y_i)(xi,yi),对WWWbbb进行更新:

W←W+ηyixi (11) W \leftarrow W+ \eta y_ix_i\ \tag{11} WW+ηyixi (11)
b←b+ηxi (12) b \leftarrow b+ \eta x_i\ \tag{12} bb+ηxi (12)

其中,η\etaη是步长,0≤η≤10 \leq \eta \leq 10η1,又称为学习率(learning rate).

4. 感知机学习算法的形式

输入:训练数据集T={(x1,y1),(x2,y2),...,(xn,yn)}T=\{(x_1,y_1),(x_2,y_2),...,(x_n,y_n)\}T={(x1,y1),(x2,y2),...,(xn,yn)},其中xi∈X=Rnx_i \in X=R^nxiX=Rn, yi∈+1,−1y_i \in {+1,-1}yi+1,1,i=1,2,...,Ni=1,2,...,Ni=1,2,...,N,学习率为η\etaη
输出:W, b;感知机模型f(x)=sign(W∗x+b)f(x)=sign(W*x+b)f(x)=sign(Wx+b)
(1)选取初始值W0W_0W0,bbb
(2)在训练集中选取数据(xi,yi)(x_i,y_i)(xi,yi)
(3)如果yi(Wxi+b)≤0y_{i}(Wx_{i}+b) \leq 0yi(Wxi+b)0
W←W+ηyixi W \leftarrow W+ \eta y_ix_i WW+ηyixi
b←b+ηxi b \leftarrow b+ \eta x_{i} bb+ηxi
(4)转至(2),直至训练集中没有误分类点。

5. python实现感知机算法

Implement the perceptron learning algorithm by yourself. Use the dataset in Table 2 to test your algorithm.

Table2

xix_{i}xi(1,2)(2,3)(3,3)(2,1)(3,2)
y111-1-1

思路:

is_error_function: 计算

−yi(W∗xi+b)(13) -y_{i}(W*x_{i}+b)\tag{13} yi(Wxi+b)(13)

loss_function:计算损失函数,其中MMM为误分类点的集合

−∑xi∈Myi(W∗xi+b)(14) -\sum_{x_{i} \in M}y_{i}(W*x_{i}+b)\tag{14} xiMyi(Wxi+b)(14)

SGD构建优化问题,随机梯度下降来优化:

minW,bL(W,b)=−∑xi∈Myi(W∗xi+b)(15) \underset{W,b}{min}L(W,b)=-\sum_{x_{i} \in M}y_{i}(W*x_{i}+b)\tag{15} W,bminL(W,b)=xiMyi(Wxi+b)(15)

import numpy as np
import matplotlib.pyplot as plt

def is_error_function(W,b,X,y):
    value = np.dot(W,X)+b   #np.dot:inner product
    value = -value*y
    return value

def loss_function(W,b,X,Y):
    sum = 0
    for i in range(len(Y)):
        if is_error_function(W,b,X[i],Y[i])>=0:
            sum += is_error_function(W,b,X[i],Y[i])
    return sum

###   SGD   ###
def SGD(W,b,X,Y,learning_rate,limit,max_iter):
    cnt = 0
    while cnt < max_iter:
        for i in range(len(Y)):
            if (is_error_function(W,b,X[i],Y[i]) >= 0):
                #print('the loss is',is_error_function(W,b,X[i],Y[i]))
                #print('the error point is','(',X[i],',',Y[i],')')
                
                print 'the error point is','(',X[i],',',Y[i],')'
                print 'the loss is',is_error_function(W,b,X[i],Y[i])
                W=np.multiply(learning_rate*Y[i],X[i])+W
                b+=learning_rate*Y[i]
                #print(W,b)
                break
            else:
                print 'this point ',X[i],' do not need to update!'
        print(W,b)
        if loss_function(W,b,X,Y) < limit:
            print 'this Algotithm is convergence!!'
            print 'loss is',loss_function(W,b,X,Y)
            for i in range(len(Y)):    
                if Y[i] == 1:
                    plt.plot(X[i][0],X[i][1],'*b')
                else:
                    plt.plot(X[i][0],X[i][1],'or')
            plt.xlim(0,3.5)
            #plt.ylim(0,3.5)
            x = np.linspace(0,3.5,100)
            y = [0.0]*100
            for i in range(len(x)):
                y[i] = (-b-x[i]*W[0])/W[1]
            plt.plot(x,y)
            plt.show() 
            print '--------------------------------------------------'
            break
        else:     
            print 'continue...'
        for i in range(len(Y)):    
            if Y[i] == 1:
                plt.plot(X[i][0],X[i][1],'*b')
            else:
                plt.plot(X[i][0],X[i][1],'or')
        plt.xlim(0,3.5)
        #plt.ylim(0,3.5)
        x = np.linspace(0,3.5,100)
        y = [0.0]*100
        for i in range(len(x)):
            y[i] = (-b-x[i]*W[0])/W[1]
        plt.plot(x,y)
        plt.show() 
        print '--------------------------------------------------'
    
        cnt += 1
    
if __name__ == '__main__':
    # W = [theta1,theta2,theta0]
    W = [-1.0,1.0]
    b=0
    X = [[1,2],[2,3],[3,3],[2,1],[3,2]]
    Y = [1,1,1,-1,-1]
    SGD(W,b,X,Y,0.1,0.0001,10)

Results:

this point [1, 2] do not need to update!

this point [2, 3] do not need to update!

the error point is ( [3, 3] , 1 )

the loss is -0.0

(array([-0.7, 1.3]), 0.1)

continue…

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tkIFDWNa-1647946576110)(_v_images/1539437929_8708.png)]
在这里插入图片描述

this point [1, 2] do not need to update!

this point [2, 3] do not need to update!

this point [3, 3] do not need to update!

the error point is ( [2, 1] , -1 )

the loss is 1.3877787807814457e-16

(array([-0.9, 1.2]), 0.0)

this Algotithm is convergence!!

loss is 0
在这里插入图片描述
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QX8YQonN-1647946576110)(_v_images/1539438080_1727.png)]

参考文献

李航. 统计学习方法.清华出版社

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值