贝叶斯算法,简单地说就是比较测试数据是各个类别的概率,概率最大的就判断为这个数据的类别。具体来说就是一个条件概率的计算,测试数据是X,某个类别是C,那么也就是求P(C|X)最大的C,等于P(X|C)P(C)/P(X),因为P(X)不变,所以不必引入计算。也就是说,比较的是P(X|C)P(C)。进一步,P(C)=C类数量/训练集总数,其中训练集总数不变,所以也可以忽略。P(X|C)=P(X1 | C).P(X2 | C).......也就是C类中X各个特征的值的条件概率。
首先需要一个测试集,然后求出各个类的数量和各个类的各个特征的各个值的条件概率(三层数组),接下来就可以对测试数据进行分类,使用上面的数据进行计算。
在本实验中,使用的是汽车数据集,分为四类。我将汽车封装为一个类,然后将数据集内容存储在一个txt中(放在项目根目录下,叫做a.txt),每一行对应一个汽车数据元素,然后在主类test类中还提供了txt转成汽车数组的方法。根据实验要求取数据集的部分作为训练集,全集作为测试集。由于全集数据其实都是已分类的数据,所以每次分完类后可以直接判断本次分类是否正确,从而得出正确率。
1.汽车类:
public class Car {
private int kind,buying,maint,door,Persons,Lug_boot,Safety;
public int getKind() {
return kind;
}
public void setKind(int kind) {
this.kind = kind;
}
public int getBuying() {
return buying;
}
public void setBuying(int buying) {
this.buying = buying;
}
public int getMaint() {
return maint;
}
public void setMaint(int maint) {
this.maint = maint;
}
public int getDoor() {
return door;
}
public void setDoor(int door) {
this.door = door;
}
public int getPersons() {
return Persons;
}
public void setPersons(int persons) {
Persons = persons;
}
public int getLug_boot() {
return Lug_boot;
}
public void setLug_boot(int lug_boot) {
Lug_boot = lug_boot;
}
public int getSafety() {
return Safety;
}
public void setSafety(int safety) {
Safety = safety;
}
}
2.test类:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
public class Test {
//将特征的各个值映射为Int,从0开始
public Car[] cars=new Car[1728];//数据集合
public int[] kindNumber=new int[4];//四个类别各自在训练集中的数目(为了简化计算直接使用数目)
public double[][][] kindP=new double[4][6][4];//四个类别的各个特点的各个值的条件概率
int numberAll=0;
int numberCotrrect=0;//总测试数和正确数
int numberTest=0;//测试集数目
public void getCars(String path){//读取数据文件转为cars
File file=new File(path);
try {
BufferedReader reader=new BufferedReader(new FileReader(file));
String s=null;
int number=0;
while ((s=reader.readLine())!=null) {
cars[number++]=txt2Car(s);
}
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
public Car txt2Car(String s){//一行文本转Car
Car car=new Car();
int start=0;
int end=0;
//1
while(!s.substring(end, end+1).equals(",")){
end++;
}
String buyings=s.substring(start,end);
int buying=0;
if (buyings.equals("vhigh")) {
buying=0;
} else {
if (buyings.equals("high")) {
buying=1;
} else {
if (buyings.equals("med")) {
buying=2;
} else {
buying=3;
}
}
}
car.setBuying(buying);
end++;
start=end;
//2
while(!s.substring(end, end+1).equals(",")){
end++;
}
String maints=s.substring(start,end);
int maint=0;
if (maints.equals("vhigh")) {
maint=0;
} else {
if (maints.equals("high")) {
maint=1;
} else {
if (maints.equals("med")) {
maint=2;
} else {
maint=3;
}
}
}
car.setMaint(maint);
end++;
start=end;
//3
while(!s.substring(end, end+1).equals(",")){
end++;
}
String doors=s.substring(start,end);
int door=0;
if (doors.equals("2")) {
door=0;
} else {
if (doors.equals("3")) {
door=1;
} else {
if (doors.equals("4")) {
door=2;
} else {
door=3;
}
}
}
car.setDoor(door);
end++;
start=end;
//4
while(!s.substring(end, end+1).equals(",")){
end++;
}
String persons=s.substring(start,end);
int person=0;
if (persons.equals("2")) {
person=0;
} else {
if (persons.equals("4")) {
person=1;
} else {
person=2;
}
}
car.setPersons(person);
end++;
start=end;
//5
while(!s.substring(end, end+1).equals(",")){
end++;
}
String lugs=s.substring(start,end);
int lug=0;
if (lugs.equals("small")) {
lug=0;
} else {
if (lugs.equals("med")) {
lug=1;
} else {
lug=2;
}
}
car.setLug_boot(lug);
end++;
start=end;
//6
while(!s.substring(end, end+1).equals(",")){
end++;
}
String safes=s.substring(start,end);
int safe=0;
if (safes.equals("low")) {
safe=0;
} else {
if (safes.equals("med")) {
safe=1;
} else {
safe=2;
}
}
car.setSafety(safe);
end++;
start=end;
//kind
String kinds=s.substring(start,s.length());
int kind=0;
if (kinds.equals("unacc")) {
kind=0;
} else {
if (kinds.equals("acc")) {
kind=1;
} else {
if (kinds.equals("good")) {
kind=2;
} else {
kind=3;
}
}
}
car.setKind(kind);
return car;
}
public void putTest(int n){//放入训练集进行训练,参数是训练集大小
numberTest=n;
numberCotrrect=0;
numberAll=0;
double [][][] k=new double[4][6][4];//四个类别的各个特点的各个值的个数
for(int i=0;i<n;i++){
Car car=cars[i];
int kind=car.getKind();
int buying=car.getKind();
int maint=car.getMaint();
int door=car.getDoor();
int persons=car.getPersons();
int lug_Boot=car.getLug_boot();
int safety=car.getSafety();
kindNumber[kind]++;
k[kind][0][buying]++;
k[kind][1][maint]++;
k[kind][2][door]++;
k[kind][3][persons]++;
k[kind][4][lug_Boot]++;
k[kind][5][safety]++;
}
// for(int i=0;i<4;i++){
// System.out.println("第"+i+"类的数目为"+kindNumber[i]);
// }
for(int i=0;i<4;i++){//为kindP赋值
for(int j=0;j<6;j++){
for(int p=0;p<4;p++){
if (kindNumber[i]>0) {
kindP[i][j][p]=k[i][j][p]/kindNumber[i];
//System.out.println("第"+i+"类第"+j+"个特征的第"+p+"个值的概率是"+kindP[i][j][p]);
}
}
}
}
}
public void work(){
for(int i=0;i<cars.length;i++){
numberAll++;
Car car=cars[i];
int kind=car.getKind();
int buying=car.getKind();
int maint=car.getMaint();
int door=car.getDoor();
int persons=car.getPersons();
int lug_Boot=car.getLug_boot();
int safety=car.getSafety();
int testKind=0;
double testP=0;
for(int j=0;j<4;j++){
double currentP=kindNumber[j]*kindP[j][0][buying]
*kindP[j][1][maint]*kindP[j][2][door]
*kindP[j][3][persons]*kindP[j][4][lug_Boot]*kindP[j][5][safety];
if (currentP>testP) {
testP=currentP;
testKind=j;
}
}
if (testKind==kind) {
numberCotrrect++;
}
}
double result=((double)(numberCotrrect*100))/numberAll;
System.out.println("测试集数:"+numberTest+" 测试正确数 :"+numberCotrrect+" 测试正确率:"+result+"%");
}
public static void main(String []args){
Test test=new Test();
test.getCars("a.txt");
test.putTest(100);
test.work();
test.putTest(200);
test.work();
test.putTest(500);
test.work();
test.putTest(700);
test.work();
test.putTest(1000);
test.work();
test.putTest(1350);
test.work();
}
}
可以看到main方法里每次测试都分两步,putTest指定全集的多少作为训练集(然后计算各个值),work遍历全集分类并统计结果。