Algorithm procedure
Input: a training set $T=\{(x_1,y_1),(x_2,y_2),\cdots,(x_N,y_N)\}$, where $x_i\in\mathcal{X}=\mathbf{R}^n$ and $y_i\in\mathcal{Y}=\{-1,+1\}$ for $i=1,2,\cdots,N$; a learning rate $\eta$ $(0<\eta\le 1)$.
Output: $w,b$; the perceptron model $f(x)=\operatorname{sign}(w\cdot x+b)$.
Procedure:
(1) Choose initial values $w_0,b_0$.
(2) Select a sample $(x_i,y_i)$ from the training set.
(3) If $y_i(w\cdot x_i+b)\le 0$, update
$$w\gets w+\eta y_i x_i,\qquad b\gets b+\eta y_i$$
(4) Go back to (2) until the training set contains no misclassified points.
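Note: when an instance is misclassified, that is, when it lies on the wrong side of the separating hyperplane, $w$ and $b$ are adjusted so that the hyperplane moves toward that misclassified point. Each update reduces the distance between the point and the hyperplane, until the hyperplane passes the point and the point is classified correctly.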
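The procedure above can be sketched in Python as follows. This is a minimal sketch under stated assumptions, not a reference implementation: the function name `train_perceptron`, the zero initial values, the `max_updates` safety cap, and the rule of always updating on the first misclassified point in index order are choices made here for illustration. The algorithm itself allows any selection order.

```python
import numpy as np

def train_perceptron(X, y, eta=1.0, max_updates=1000):
    """Primal perceptron sketch: returns (w, b, number_of_updates)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    w = np.zeros(X.shape[1])   # step (1): w_0 = 0 (assumed initial value)
    b = 0.0                    # step (1): b_0 = 0 (assumed initial value)
    updates = 0
    while updates < max_updates:   # cap is an assumption, guards non-separable data
        # steps (2)-(3): look for points with y_i (w . x_i + b) <= 0
        margins = y * (X @ w + b)
        wrong = np.flatnonzero(margins <= 0)
        if wrong.size == 0:
            break                  # step (4): no misclassified points remain
        i = wrong[0]               # assumed rule: first misclassified point
        w = w + eta * y[i] * X[i]
        b = b + eta * y[i]
        updates += 1
    return w, b, updates
```

Called as `train_perceptron([[3, 3], [4, 3], [1, 1]], [1, 1, -1])`, that is, on the three points of Example 2.1 with $\eta=1$, this selection rule stops after 7 updates at $w=(1,1)^T$, $b=-3$.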
Worked example
Example 2.1: The training set has two positive instances, $x_1=(3,3)^T$ and $x_2=(4,3)^T$, and one negative instance, $x_3=(1,1)^T$. Use the primal form of the perceptron learning algorithm to find the perceptron model $f(x)=\operatorname{sign}(w\cdot x+b)$. Here $w=(w^{(1)},w^{(2)})^T$ and $x=(x^{(1)},x^{(2)})^T$.
Approach:
Set up the optimization problem
$$\min_{w,b}L(w,b)=-\sum_{x_i\in M}y_i(w\cdot x_i+b)$$
where $M$ is the set of misclassified points, and solve for $w,b$ with the procedure above, taking $\eta=1$.
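To make the objective concrete, the loss and its gradient can be evaluated numerically on the three points of this example. The sketch below is illustrative only: the names `perceptron_loss` and `perceptron_grad` are hypothetical, and the gradient is taken with the misclassified set $M$ held fixed.

```python
import numpy as np

def perceptron_loss(w, b, X, y):
    """L(w, b) = -sum of y_i (w . x_i + b) over the misclassified set M."""
    margins = y * (X @ w + b)
    in_M = margins <= 0               # M: points with y_i (w . x_i + b) <= 0
    return float(np.sum(-margins[in_M]))

def perceptron_grad(w, b, X, y):
    """Gradient of L with M held fixed: (-sum y_i x_i, -sum y_i) over M."""
    margins = y * (X @ w + b)
    in_M = margins <= 0
    grad_w = -(y[in_M][:, None] * X[in_M]).sum(axis=0)
    grad_b = float(-np.sum(y[in_M]))
    return grad_w, grad_b

X = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])
y = np.array([1.0, 1.0, -1.0])

w1, b1 = np.array([3.0, 3.0]), 1.0      # parameters after updating on x_1
print(perceptron_loss(w1, b1, X, y))     # 7.0, only x_3 is misclassified
print(perceptron_grad(w1, b1, X, y))     # gradient (1, 1) for w and 1 for b

w7, b7 = np.array([1.0, 1.0]), -3.0     # w = (1, 1), b = -3
print(perceptron_loss(w7, b7, X, y))     # 0.0, M is empty
```

With $\eta=1$, stepping from $w=(3,3)^T$, $b=1$ against this gradient gives $w=(2,2)^T$, $b=0$, which is the same as the update $w\gets w+\eta y_3x_3$, $b\gets b+\eta y_3$ on the single misclassified point $x_3$.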
Solution:
(1) Take the initial values $w_0=0$, $b_0=0$.
(2) Take $x_1=(3,3)^T$. Since $y_1(w_0\cdot x_1+b_0)=0$, the condition $y_i(w\cdot x_i+b)\le 0$ holds, so $x_1$ is not classified correctly and $w,b$ are updated:
$$w_1=w_0+y_1x_1=(3,3)^T,\qquad b_1=b_0+y_1=1$$
This gives the linear model
$$w_1\cdot x+b_1=\begin{bmatrix} 3 \\ 3 \end{bmatrix}\cdot x+1=3x^{(1)}+3x^{(2)}+1$$
(3) For $x_1$ and $x_2$, clearly $y_i(w_1\cdot x_i+b_1)>0$, so both are classified correctly and $w,b$ are left unchanged. For $x_3=(1,1)^T$, $y_3(w_1\cdot x_3+b_1)=-7<0$, so the condition $y_i(w\cdot x_i+b)\le 0$ holds, $x_3$ is not classified correctly, and $w,b$ are updated:
$$w_2=w_1+y_3x_3=\begin{bmatrix} 3 \\ 3 \end{bmatrix}+(-1)\cdot\begin{bmatrix} 1 \\ 1 \end{bmatrix}=\begin{bmatrix} 2 \\ 2 \end{bmatrix}=(2,2)^T$$
$$b_2=b_1+y_3=1+(-1)=0$$
This gives the linear model
$$w_2\cdot x+b_2=\begin{bmatrix} 2 \\ 2 \end{bmatrix}\cdot x+0=2x^{(1)}+2x^{(2)}$$
(4) After every update of $w,b$, the training set is scanned again from the beginning. Continuing in this way gives
$$w_7=(1,1)^T,\qquad b_7=-3$$
$$w_7\cdot x+b_7=\begin{bmatrix} 1 \\ 1 \end{bmatrix}\cdot x+(-3)=x^{(1)}+x^{(2)}-3$$
At this point $y_i(w_7\cdot x_i+b_7)>0$ for every data point, so there are no misclassified points and the loss function attains its minimum value of 0.
The separating hyperplane is $x^{(1)}+x^{(2)}-3=0$.
The perceptron model is $f(x)=\operatorname{sign}(x^{(1)}+x^{(2)}-3)$.
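The claim that no point is misclassified by $w=(1,1)^T$, $b=-3$ can be checked directly. The short sketch below is only a sanity check: the helper name `predict` is hypothetical, and mapping a score of exactly 0 to $+1$ is an assumed convention for $\operatorname{sign}$.

```python
import numpy as np

X = np.array([[3, 3], [4, 3], [1, 1]])   # x_1, x_2, x_3
y = np.array([1, 1, -1])                 # their labels
w, b = np.array([1, 1]), -3              # final parameters of the example

def predict(points, w, b):
    """f(x) = sign(w . x + b); a score of exactly 0 maps to +1 (assumed convention)."""
    return np.where(points @ w + b >= 0, 1, -1)

margins = y * (X @ w + b)
print(margins)             # y_i (w . x_i + b) for x_1, x_2, x_3
print(predict(X, w, b))    # predicted labels, to compare with y
```

The three margins are 3, 4 and 1, all strictly positive, and the predicted labels agree with $y$.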
Iterations of the solution
Iteration | Misclassified point | $w$ | $b$ | $w\cdot x+b$ |
---|---|---|---|---|
0 | | $0$ | $0$ | $0$ |
1 | $x_1$ | $(3,3)^T$ | $1$ | $3x^{(1)}+3x^{(2)}+1$ |
2 | $x_3$ | $(2,2)^T$ | $0$ | $2x^{(1)}+2x^{(2)}$ |
3 | $x_3$ | $(1,1)^T$ | $-1$ | $x^{(1)}+x^{(2)}-1$ |
4 | $x_3$ | $(0,0)^T$ | $-2$ | $-2$ |
5 | $x_1$ | $(3,3)^T$ | $-1$ | $3x^{(1)}+3x^{(2)}-1$ |
6 | $x_3$ | $(2,2)^T$ | $-2$ | $2x^{(1)}+2x^{(2)}-2$ |
7 | $x_3$ | $(1,1)^T$ | $-3$ | $x^{(1)}+x^{(2)}-3$ |
8 | none | $(1,1)^T$ | $-3$ | $x^{(1)}+x^{(2)}-3$ |
Note: the separating hyperplane and perceptron model above result from taking the misclassified points in the order $x_1,x_3,x_3,x_3,x_1,x_3,x_3$. If the misclassified points are instead taken in the order $x_1,x_3,x_3,x_3,x_2,x_3,x_3,x_3,x_1,x_3,x_3$, the resulting separating hyperplane is $2x^{(1)}+x^{(2)}-5=0$. The solution therefore depends on the order in which misclassified points are chosen.
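Both selection orders can be replayed to confirm the two hyperplanes. The sketch below is illustrative: the helper names `replay` and `separates` are hypothetical, $w_0=0$, $b_0=0$ and $\eta=1$ are taken from the example, and `replay` raises an error if a listed point is not actually misclassified when its turn comes.

```python
import numpy as np

X = {1: np.array([3, 3]), 2: np.array([4, 3]), 3: np.array([1, 1])}
Y = {1: 1, 2: 1, 3: -1}

def replay(order, eta=1):
    """Apply perceptron updates for the given sequence of point indices."""
    w, b = np.zeros(2, dtype=int), 0
    for i in order:
        if Y[i] * (w @ X[i] + b) > 0:
            raise ValueError(f"x{i} is already correctly classified")
        w = w + eta * Y[i] * X[i]
        b = b + eta * Y[i]
    return w, b

def separates(w, b):
    """True if every training point satisfies y_i (w . x_i + b) > 0."""
    return all(Y[i] * (w @ X[i] + b) > 0 for i in X)

order_a = [1, 3, 3, 3, 1, 3, 3]
order_b = [1, 3, 3, 3, 2, 3, 3, 3, 1, 3, 3]
for order in (order_a, order_b):
    w, b = replay(order)
    print(order, "->", w, b, separates(w, b))
```

The first order ends at $w=(1,1)^T$, $b=-3$ and the second at $w=(2,1)^T$, $b=-5$; both separate the training set.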
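Code implementation of the algorithm

    import numpy as np

    # Training set of Example 2.1. Assumed layout (its definition is not shown
    # in the listing): one {point: label} dict per sample.
    data = [{(3, 3): 1}, {(4, 3): 1}, {(1, 1): -1}]

    w = np.zeros((1, 2))  # weight row vector, w_0 = 0
    print(w)
    b = 0                 # bias, b_0 = 0
    print(b)
    flg = 0               # set to 1 when the current pass made an update
    while True:
        for index in data:
            x = 0
            y = 0
            for in_data in index:  # unpack the single (point, label) pair
                print("++++++")
                x = in_data
                y = index[in_data]
                print(x)
            value = np.array(x).reshape(2, 1)  # column vector x_i
            print(value)
            print(np.dot(w, value))            # w . x_i
            f = y * (np.dot(w, value) + b)     # y_i (w . x_i + b)
            print(f)
            if f[0][0] <= 0:                   # misclassified: update with eta = 1
                w = w + y * np.array(x)
                b = b + y
                print(w, b)
                flg = 1
                break                          # rescan from the first sample
        if flg == 1:
            flg = 0
            continue
        else:
            break                              # a full pass without updates: done
    print("==========")
    print(w)
    print(b)

Output:
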
    [[ 0. 0.]]
    0
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 0.]]
    [[ 0.]]
    [[ 3. 3.]] 1
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 18.]]
    [[ 19.]]
    ++++++
    (4, 3)
    [[4]
     [3]]
    [[ 21.]]
    [[ 22.]]
    ++++++
    (1, 1)
    [[1]
     [1]]
    [[ 6.]]
    [[-7.]]
    [[ 2. 2.]] 0
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 12.]]
    [[ 12.]]
    ++++++
    (4, 3)
    [[4]
     [3]]
    [[ 14.]]
    [[ 14.]]
    ++++++
    (1, 1)
    [[1]
     [1]]
    [[ 4.]]
    [[-4.]]
    [[ 1. 1.]] -1
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 6.]]
    [[ 5.]]
    ++++++
    (4, 3)
    [[4]
     [3]]
    [[ 7.]]
    [[ 6.]]
    ++++++
    (1, 1)
    [[1]
     [1]]
    [[ 2.]]
    [[-1.]]
    [[ 0. 0.]] -2
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 0.]]
    [[-2.]]
    [[ 3. 3.]] -1
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 18.]]
    [[ 17.]]
    ++++++
    (4, 3)
    [[4]
     [3]]
    [[ 21.]]
    [[ 20.]]
    ++++++
    (1, 1)
    [[1]
     [1]]
    [[ 6.]]
    [[-5.]]
    [[ 2. 2.]] -2
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 12.]]
    [[ 10.]]
    ++++++
    (4, 3)
    [[4]
     [3]]
    [[ 14.]]
    [[ 12.]]
    ++++++
    (1, 1)
    [[1]
     [1]]
    [[ 4.]]
    [[-2.]]
    [[ 1. 1.]] -3
    ++++++
    (3, 3)
    [[3]
     [3]]
    [[ 6.]]
    [[ 3.]]
    ++++++
    (4, 3)
    [[4]
     [3]]
    [[ 7.]]
    [[ 4.]]
    ++++++
    (1, 1)
    [[1]
     [1]]
    [[ 2.]]
    [[ 1.]]
    ==========
    [[ 1. 1.]]
    -3