The Bayesian classifier [1] is conceptually a rather simple classifier, but combined with different network structures and probability models it can grow into very complex classification systems. This short article demonstrates how a Bayesian classifier with Gaussian class-conditional densities solves a two-class problem.
By Bayes' rule,

p(y|x) = p(x|y) p(y) / p(x),

where the denominator merely normalizes the posterior. p(y) is the prior, and p(x|y) is the conditional probability, also called the class-conditional density (the density of x given its class). In this article, p(x|y) is assumed to be Gaussian [2]:

p(x|y) = (2π)^(-n/2) |Σ_y|^(-1/2) exp( -(1/2) (x - μ_y)' Σ_y^(-1) (x - μ_y) )
while p(y) follows a Bernoulli distribution [3]:

p(y = 0) = η,  p(y = 1) = 1 - η
The maximum-likelihood estimate of η is simply the fraction of all training samples that belong to class 0. Once the five parameters (μ_0, Σ_0, μ_1, Σ_1, η) have been estimated, any sample x can be classified by comparing the posterior probabilities through the log ratio

f(x) = ln p(y = 0 | x) - ln p(y = 1 | x)

When f(x) > 0, p(y = 0 | x) > p(y = 1 | x), so x is assigned to class 0; otherwise it is assigned to class 1. (In the code below, class 0 carries the label +1 and class 1 the label -1.)
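Substituting the Gaussian densities and the Bernoulli prior into the log ratio gives the quadratic discriminant that the code below evaluates (a reconstruction from the code; the (2π)^(-n/2) factors cancel):

```latex
f(x) = \ln\frac{p(y=0\mid x)}{p(y=1\mid x)}
     = \frac{1}{2}(x-\mu_1)^{\top}\Sigma_1^{-1}(x-\mu_1)
     - \frac{1}{2}(x-\mu_0)^{\top}\Sigma_0^{-1}(x-\mu_0)
     + \frac{1}{2}\ln\frac{|\Sigma_1|}{|\Sigma_0|}
     + \ln\frac{\eta}{1-\eta}
```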
The demonstration below uses Matlab.

Training code:
function [model_pos, model_neg] = FindGuassianModel(x, y)
%FINDGUASSIANMODEL Maximum-likelihood fit of one Gaussian model per class.
%   x : n-by-N matrix with one sample per column
%   y : vector of N labels taking values +1 / -1
%   Each returned model holds the class mean (mu), covariance (var) and prior.
x_pos = x(:, y == 1);
model_pos.mu = mean(x_pos, 2);
model_pos.var = cov(x_pos');                    % cov expects observations in rows
model_pos.prior = size(x_pos, 2) / size(x, 2);  % size(...,2) counts samples, not max(dim)
x_neg = x(:, y ~= 1);
model_neg.mu = mean(x_neg, 2);
model_neg.var = cov(x_neg');
model_neg.prior = size(x_neg, 2) / size(x, 2);
end
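For readers without Matlab, here is a rough NumPy equivalent of the fitting step (an illustrative sketch, not part of the original post; the function and key names are my own, chosen to mirror the Matlab structs, and samples are again stored one per column):

```python
import numpy as np

def find_gaussian_model(x, y):
    """ML-fit a Gaussian (mean, covariance) plus a class prior per class.

    x : (n_dims, N) array, one sample per column (as in the Matlab code)
    y : (N,) array of labels in {+1, -1}
    Returns two dicts with keys 'mu', 'var', 'prior'.
    """
    def fit(cols):
        return {
            'mu': cols.mean(axis=1),
            'var': np.cov(cols),  # rows are variables, columns observations (default)
            'prior': cols.shape[1] / x.shape[1],
        }
    return fit(x[:, y == 1]), fit(x[:, y != 1])
```

Like Matlab's `cov`, `np.cov` normalizes by N-1 by default, so the two fits agree.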
Computing the classification error:
function [err, h] = FindModelError(model_pos, model_neg, x, y)
%FINDMODELERROR Classify the columns of x and count the misclassifications.
%   h(i) is the predicted label (+1/-1) for sample x(:,i); err is the
%   number of samples for which h(i) differs from the true label y(i).
mu1 = model_pos.mu;
sigma1 = model_pos.var;
p1 = model_pos.prior;
mu2 = model_neg.mu;
sigma2 = model_neg.var;
p2 = model_neg.prior;
% Constant part of the log posterior ratio f(x)
bias = 0.5*log(det(sigma2)) - 0.5*log(det(sigma1)) + log(p1/p2);
err = 0;
h = zeros(size(y));
for i = 1:length(y)
    % Quadratic part of f(x); the v'/A*v form solves a linear system
    % rather than forming an explicit matrix inverse
    c = bias + 0.5*(x(:,i)-mu2)'/sigma2*(x(:,i)-mu2) ...
             - 0.5*(x(:,i)-mu1)'/sigma1*(x(:,i)-mu1);
    if c > 0
        h(i) = 1;
    else
        h(i) = -1;
    end
    if h(i) ~= y(i)
        err = err + 1;
    end
end
end
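The same decision rule in NumPy (again an illustrative sketch; the model dicts hold a class mean 'mu', covariance 'var' and prior 'prior', names of my own mirroring the Matlab structs):

```python
import numpy as np

def find_model_error(model_pos, model_neg, x, y):
    """Classify columns of x with the quadratic Bayes discriminant, count errors."""
    mu1, s1, p1 = model_pos['mu'], model_pos['var'], model_pos['prior']
    mu2, s2, p2 = model_neg['mu'], model_neg['var'], model_neg['prior']
    # Constant part of the log posterior ratio f(x)
    bias = (0.5 * np.log(np.linalg.det(s2)) - 0.5 * np.log(np.linalg.det(s1))
            + np.log(p1 / p2))
    h = np.empty(len(y))
    for i in range(len(y)):
        d1 = x[:, i] - mu1
        d2 = x[:, i] - mu2
        # f(x) > 0  <=>  the positive class has the larger posterior
        f = (bias + 0.5 * d2 @ np.linalg.solve(s2, d2)
                  - 0.5 * d1 @ np.linalg.solve(s1, d1))
        h[i] = 1 if f > 0 else -1
    err = int(np.sum(h != y))
    return err, h
```

As in the Matlab version, `np.linalg.solve` replaces an explicit matrix inverse.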
Main demo script:
%%
clc;
clear;
close all;
%% generate random data
shift =3.0;
n = 2;%2 dim
sigma = 1;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1);-ones(N/2,1)];
%show the data
figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d training data');
legend('Positive samples','Negative samples','Location','SouthEast');
% model fitting using maximum likelihood
[model_pos,model_neg] = FindGuassianModel(x,y);
%% test on new dataset, same distribution
n = 2;%2 dim
sigma = 2;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1);-ones(N/2,1)];
figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d testing data');
hold on;
%% gaussian model as a baseline
[err,h] = FindModelError(model_pos,model_neg,x,y);
fprintf('Bayesian error on the test data set: %f\n',err/N);
x_pos = x(:,h==1);
x_neg = x(:,h~=1);
plot(x_pos(1,:),x_pos(2,:),'r.');
hold on;
plot(x_neg(1,:),x_neg(2,:),'g.');
legend('Positive samples','Negative samples','Positive samples as predicted','Negative samples as predicted','Location','SouthEast');
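The whole experiment can also be reproduced without Matlab. The sketch below repeats the data generation, fitting, and error count in NumPy (plotting omitted; the seed and helper names are my own choices, so the exact error rate will differ slightly from the Matlab run):

```python
import numpy as np

rng = np.random.default_rng(42)
n, N, shift = 2, 500, 3.0

def make_data(sigma):
    # Positive class around (-shift, -shift), negative around (+shift, +shift)
    x = np.hstack([rng.standard_normal((n, N // 2)) - shift,
                   rng.standard_normal((n, N // 2)) * sigma + shift])
    y = np.concatenate([np.ones(N // 2), -np.ones(N // 2)])
    return x, y

def fit(cols, total):
    return cols.mean(axis=1), np.cov(cols), cols.shape[1] / total

x, y = make_data(sigma=1)                 # training set
mu1, s1, p1 = fit(x[:, y == 1], N)
mu2, s2, p2 = fit(x[:, y != 1], N)

x, y = make_data(sigma=2)                 # test set, wider negative class
bias = 0.5 * np.log(np.linalg.det(s2) / np.linalg.det(s1)) + np.log(p1 / p2)
err = 0
for i in range(N):
    d1, d2 = x[:, i] - mu1, x[:, i] - mu2
    f = (bias + 0.5 * d2 @ np.linalg.solve(s2, d2)
              - 0.5 * d1 @ np.linalg.solve(s1, d1))
    err += (1 if f > 0 else -1) != y[i]
print('Bayesian error on the test data set:', err / N)
```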
The final test results:

Judging from the test results, most samples are classified correctly (dots of each color fall inside the circles or squares of the same color); only 0.8% of the points are misclassified.

All of the code in this article can be downloaded from my resources page: http://download.youkuaiyun.com/detail/ranchlai/6018299
[1] http://en.wikipedia.org/wiki/Bayes_classifier
[2] http://en.wikipedia.org/wiki/Multivariate_normal_distribution
[3] http://en.wikipedia.org/wiki/Bernoulli_distribution