统计学习笔记(二)感知机

第二章 感知机

2.1 感知机模型

假设输入空间(特征空间)是XRn,输出空间是y={+1, -1},输入x=X表示一个实例的特征向量,输出表示实例的类别。由输入空间到输出空间的函数

f(x)=sign(wx+b)
称为感知机。其中wRn是权值向量,bR是偏置,wx是w和x的内积。sign是符号函数
sign(x)={+1,1,x0x<0

感知机是一种线性分类模型,判别模型。感知机的假设空间是定义在特征空间中的函数集合
{f|f(x)=wx+b}

感知机可以解释如下:线性方程
wx+b=0
对应于特征空间Rn 中的一个超平面。其中w是超平面的法向量,b是超平面的截距。这个超平面将特征空间分为2部分,也将特征空间里的点分为正、负两部分。所以,超平面也叫分离超平面。

2.2 感知机学习策略

线性可分与非线性可分

数据集T的线性可分性与非线性可分性。其中

T={(x1,y1),(x2,y2),...(xN,yN)}xiXRn,yiY={+1,1},i=1,2,...N

感知机学习策略

定义损失函数
特征空间中任意选取一个点x0=(x(0)1,x(0)2,...x(0)N),点到超平面的距离

l=wx+bw

对于一个误分类的点(xi,yi),有yi(wxi+b)>0。因此,误分类点xi到超平面S的距离为
1wyi(wxi+b)

对于给定的测试数据集
T=((x1,y1),(x2,y2),...(xN,yN))

其中xiX=Rn,yiY={+1,1}
感知机学习的损失函数定义为
L(w,b)=xiMyi(wxi+b)

其中M是所有误分类点的集合。

2.3 感知机学习算法

2.3.1 算法的原型

minw,bL(w,b)

感知机学习算法是由误分类驱动的。 首先任意选取一个超平面(w0,b0),然后用梯度下降法不断的极小化目标函数L(w,b)。对于误分类点的集合M,损失函数L(w,b)的梯度
wL(w,b)=xiMyixibL(w,b)=xiMyi

随机选取一个误分类点(xi,yi),对w,b进行更新:
ww+ηyixibb+ηyi

式子中η(0<η1)是步长,在统计学中叫做学习率。

算法2.1 感知机学习的原始形式
输入:训练数据集T={(x1,y1),(x2,y2),...(xN,yN)}xiXRn,yiY={+1,1},i=1,2,...Nη(0<η1)
输出:w,bf(x)=sign(wx+b)
(1)选取初值:w=w0,b=b0
(2)在训练集T中选取数据(xi,yi)
(3)如果yi(wxi+b)0

ww+ηyixibb+ηyi

(4)转至(2),直到训练集中没有误分类点。
如下流程图。虽然画的丑了点儿,凑合看吧。
Created with Raphaël 2.1.0StartIsMisFound?Update(w,b)Overyesno

2.3.2

算法的收敛性

2.3.3 算法的对偶形式

对偶形式的运算相对的快?不过其实这也是没什么卵用。那么多的乘法,CPU肯定吃不消,GPU?算了开始笔记。

算法实现

import numpy as np
import matplotlib.pyplot as plt

global gw
global gb
global depth
global x
global y
# xt is the transposition of x. only used by pylib.
x = [ [3,3], [4,3], [1,1] ]
y = [ +1,    +1,    -1 ]
xt = list(zip(*x))
gw = None
gb = None
depth = 0
if len(x[0]):
    depth = len(x[0])

def initwb(w=[], b=0):
    # if len(w) < depth: fill 0
    # if len(w) = depth: ok
    # if len(w) > depth: cut w.
    global gw, gb, depth
    lenw = len(w)
    if lenw == 0:
        gw = [0] * depth
    else:
        if lenw < depth:
            gw = w[:]
            gw.extend([0]*(depth-lenw))
        elif lenw >= depth:
            gw = w[:depth]
    gb = b

'''
loss <= 0, wrongly classified
loss >0  , successfully classfied.
'''
def loss(i, w, b):
    global depth,x,y
    res = 0;
    assert len(x[i]) == depth, "Length of testdatum isn't equal to "+str(depth)
    res = np.multiply( y[i], np.dot(x[i],w)+b )
    return res

'''
update w,b using i-th data.
para1: i, index of data.
para2: w, matrix w with deep $depth$
para3: b, value of b
para4: eta, eta
'''
def updatewb(i, w, b, eta):
    global x, y, gw, gb
    res = None
    eta = 1
    gw = list(np.add( w, np.multiply(y[i], x[i]) ))
    gb = b + y[i]
    return (gw, gb)

'''
looking through all test data for a misclassified one. If found return the index;
if not return -1.
'''
def misclassified(w, b):
    res = -1
    for i in range(len(x)):
        if loss(i, w, b) <= 0:
            res = i
            break
    return res

def prtmodel():
    print("w is", gw, "; b is", gb, ".")


if __name__ == "__main__":
    initwb()
    loop = 100
    while loop >= 0:
        loop = loop - 1
        prtmodel()
        indx = misclassified(gw, gb)
        if indx != -1:
            #print("Index", indx, "is found.")
            updatewb(indx, gw, gb, 1)
        else:
            print("no misclassified data found!")
            break

'''
# pylot draws graph.
plt.figure(figsize=(6,6))
plt.plot( xt[0], xt[1], 'go', label='Test Data' )
plt.xlabel("$x^{(1)}$")
plt.ylabel("$x^{(2)}$")
plt.title("Machine Learning: Perceptron 1")
plt.xlim(0,5)
plt.ylim(-1,4)
ax = plt.gca()
ax.set_aspect(1)
plt.legend()
plt.show()
'''

======= RESTART: C:\Users\Public\Documents\KS\PY\ML\2.1_Perceptron.py =======
w:[0,0];b:0.f(x(1),x(2))=0x(1)+0x(2)+(0)
w:[3,3];b:1.f(x(1),x(2))=3x(1)+3x(2)+(1)
w:[2,2];b:0.f(x(1),x(2))=2x(1)+2x(2)+(0)
w:[1,1];b:1.f(x(1),x(2))=1x(1)+1x(2)+(1)
w:[0,0];b:2.f(x(1),x(2))=0x(1)+0x(2)+(2)
w:[3,3];b:1.f(x(1),x(2))=3x(1)+3x(2)+(1)
w:[2,2];b:2.f(x(1),x(2))=2x(1)+2x(2)+(2)
w:[1,1];b:3.f(x(1),x(2))=1x(1)+1x(2)+(3)
no misclassified data found!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值