1.id3算法介绍
ID3算法是一种贪心算法,用来构造决策树。ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样例。
2.优点
* ID3算法避免了搜索不完整假设空间的一个主要风险:假设空间可能不包含目标函数。
*ID3算法在搜索的每一步都使用当前的所有训练样例,大大降低了对个别训练样例错误的敏感性。
*ID3算法在搜索过程中不进行回溯。所以,它易受无回溯的爬山搜索中的常见风险影响:收敛到局部最优而不是全局最优。
*ID3算法只能处理离散值的属性。信息增益度量存在一个内在偏置,它偏袒具有较多值的属性。
*ID3算法增长树的每一个分支的深度,直到恰好能对训练样例完美地分类,存在决策树过度拟合。
3.具体实现
1.测试数据
1.数据一
色泽,根蒂,敲声,纹理,脐部,触感,好瓜
青绿,蜷缩,浊响,清晰,凹陷,硬滑,好瓜
乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,好瓜
乌黑,蜷缩,浊响,清晰,凹陷,硬滑,好瓜
青绿,蜷缩,沉闷,清晰,凹陷,硬滑,好瓜
浅白,蜷缩,浊响,清晰,凹陷,硬滑,好瓜
青绿,稍蜷,浊响,清晰,稍凹,软粘,好瓜
乌黑,稍蜷,浊响,稍糊,稍凹,软粘,好瓜
乌黑,稍蜷,浊响,清晰,稍凹,硬滑,好瓜
乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,坏瓜
青绿,硬挺,清脆,清晰,平坦,软粘,坏瓜
浅白,硬挺,清脆,模糊,平坦,硬滑,坏瓜
浅白,蜷缩,浊响,模糊,平坦,软粘,坏瓜
青绿,稍蜷,浊响,稍糊,凹陷,硬滑,坏瓜
浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,坏瓜
乌黑,稍蜷,浊响,清晰,稍凹,软粘,坏瓜
浅白,蜷缩,浊响,模糊,平坦,硬滑,坏瓜
青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,坏瓜
2.数据二
天气,温度,湿度,风况,适宜
晴天,高温,中湿,无风,不宜
晴天,高温,中湿,有风,不宜
多云,高温,低湿,无风,适宜
雨天,低温,高湿,无风,适宜
雨天,低温,低湿,无风,适宜
雨天,低温,低湿,有风,不宜
多云,低温,低湿,有风,适宜
晴天,中温,高湿,无风,不宜
晴天,低温,低湿,无风,适宜
雨天,中温,低湿,无风,适宜
晴天,中温,低湿,有风,适宜
多云,中温,中湿,有风,适宜
多云,高温,低湿,无风,适宜
雨天,中温,低湿,有风,不宜
2.代码实现
ID3.java
package package1;
package package1;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
class treeNode{//树节点
private String sname;//节点名
public treeNode(String str) {
sname=str;
}
public String getsname() {
return sname;
}
ArrayList<String> label=new ArrayList<String>();//和子节点间的边标签
ArrayList<treeNode> node=new ArrayList<treeNode>();//对应子节点
}
public class ID3 {
private ArrayList<String> label=new ArrayList<String>();//特征标签
private ArrayList<ArrayList<String>> date=new ArrayList<ArrayList<String>>();//数据集
private ArrayList<ArrayList<String>> test=new ArrayList<ArrayList<String>>();//测试数据集
private ArrayList<String> sum=new ArrayList<String>();//分类种类数
private String kind;
public ID3(String path,String path0) throws FileNotFoundException {
//初始化训练数据并得到分类种数
getDate(path);
//获取测试数据集
gettestDate(path0);
init(date);
}
public void init(ArrayList<ArrayList<String>> date) {
//得到种类数
sum.add(date.get(0).get(date.get(0).size()-1));
for(int i=0;i<date.size();i++) {
if(sum.contains(date.get(i).get(date.get(0).size()-1))==false) {
sum.add(date.get(i).get(date.get(0).size()-1));
}
}
}
//获取测试数据集
public void gettestDate(String path) throws FileNotFoundException {
String str;
int i=0;
try {
//BufferedReader in=new BufferedReader(new FileReader(path));
FileInputStream fis = new FileInputStream(path);
InputStreamReader isr = new InputStreamReader(fis, "GB2312");
BufferedReader in = new BufferedReader(isr);
while((str=in.readLine())!=null) {
String[] strs=str.split(",");
ArrayList<String> line =new ArrayList<String>();
for(int j=0;j<strs.length;j++) {
line.add(strs[j]);
//System.out.print(strs[j]+" ");
}
test.add(line);
//System.out.println();
i++;
}
in.close();
}catch(Exception e) {
e.printStackTrace();
}
}
//获取训练数据集
public void getDate(String path) throws FileNotFoundException {
String str;
int i=0;
try {
//BufferedReader in=new BufferedReader(new FileReader(path));
FileInputStream fis = new FileInputStream(path);
InputStreamReader isr = new InputStreamReader(fis, "GB2312");
BufferedReader in = new BufferedReader(isr);
while((str=in.readLine())!=null) {
if(i==0) {
String[] strs=str.split(",");
for(int j=0;j<strs.length;j++) {
label.add(strs[j]);
//System.out.print(strs[j]+" ");
}
i++;
//System.out.println();
continue;
}
String[] strs=str.split(",");
ArrayList<String> line =new ArrayList<String>();
for(int j=0;j<strs.length;j++) {
line.add(strs[j]);
//System.out.print(strs[j]+" ");
}
date.add(line);
//System.out.println();
i++;
}
in.close();
}catch(Exception e) {
e.printStackTrace();
}
}
public double Ent(ArrayList<ArrayList<String>> dat) {
//计算总的信息熵
int all=0;
double amount=0.0;
for(int i=0;i<sum.size();i++) {
for(int j=0;j<dat.size();j++) {
if(sum.get(i).equals(dat.get(j).get(dat.get(0).size()-1))) { all++;
}
}
if((double)all/dat.size()==0.0) {
continue;
}
//计算信息熵
amount+=((double)all/dat.size())*(Math.log(((double)all/dat.size()))/Math.log(2.0));
all=0;
}
if(amount==0.0) {
return 0.0;
}
return -amount;//计算信息熵
}
//计算条件熵并返回信息增益值
public double condtion(int a,ArrayList<ArrayList<String>> dat) {
ArrayList<String> all=new ArrayList<String>();
double c=0.0;
all.add(dat.get(0).get(a));
//得到属性种类
for(int i=0;i<dat.size();i++) {
if(all.contains(dat.get(i).get(a))==false) {
all.add(dat.get(i).get(a));
}
}
ArrayList<ArrayList<String>> plus=new ArrayList<ArrayList<String>>();
//部分分组
ArrayList<ArrayList<ArrayList<String>>> count=new ArrayList<ArrayList<ArrayList<String>>>();
//分组总和
for(int i=0;i<all.size();i++) {
for(int j=0;j<dat.size();j++) {
if(true==all.get(i).equals(dat.get(j).get(a))) {
plus.add(dat.get(j));
}
}
count.add(plus);
c+=((double)count.get(i).size()/dat.size())*Ent(count.get(i));
plus.removeAll(plus);
}
return (Ent(dat)-c);
//返回条件熵
}
//计算信息增益最大属性
public int Gain(ArrayList<ArrayList<String>> dat) {
ArrayList<Double> num=new ArrayList<Double>();
//保存各信息增益值
for(int i=0;i<dat.get(0).size()-1;i++) {
num.add(condtion(i,dat));
}
int index=0;
double max=num.get(0);
for(int i=1;i<num.size();i++) {
if(max<num.get(i)) {
max=num.get(i);
index=i;
}
}
//System.out.println("<"+label.get(index)+">");
return index;
}
//构建决策树
public treeNode creattree(ArrayList<ArrayList<String>> dat) {
int index=Gain(dat);
treeNode node=new treeNode(label.get(index));
ArrayList<String> s=new ArrayList<String>();//属性种类
s.add(dat.get(0).get(index));
//System.out.println(dat.get(0).get(index));
for(int i=1;i<dat.size();i++) {
if(s.contains(dat.get(i).get(index))==false) {
s.add(dat.get(i).get(index));
//System.out.println(dat.get(i).get(index));
}
}
ArrayList<ArrayList<String>> plus=new ArrayList<ArrayList<String>>();
//部分分组
ArrayList<ArrayList<ArrayList<String>>> count=new ArrayList<ArrayList<ArrayList<String>>>();
//分组总和
//得到节点下的边标签并分组
for(int i=0;i<s.size();i++) {
node.label.add(s.get(i));//添加边标签
//System.out.print("添加边标签:"+s.get(i)+" ");
for(int j=0;j<dat.size();j++) {
if(true==s.get(i).equals(dat.get(j).get(index))) {
plus.add(dat.get(j));
}
}
count.add(plus);
//System.out.println();
//以下添加结点
int k;
String str=count.get(i).get(0).get(count.get(i).get(0).size()-1);
for(k=1;k<count.get(i).size();k++) {
if(false==str.equals(count.get(i).get(k).get(count.get(i).get(k).size()-1))) {
break;
}
}
if(k==count.get(i).size()) {
treeNode dd=new treeNode(str);
node.node.add(dd);
//System.out.println("这是末端:"+str);
}
else {
//System.out.print("寻找新节点:");
node.node.add(creattree(count.get(i)));
}
plus.removeAll(plus);
}
return node;
}
//输出决策树
public void print(ArrayList<ArrayList<String>> dat) {
System.out.println("构建的决策树如下:\n ");
treeNode node=null;
node=creattree(dat);//类
put(node);//递归调用
}
//用于递归的函数
public void put(treeNode node) {
System.out.println("结点:"+node.getsname()+"\n");
for(int i=0;i<node.label.size();i++) {
System.out.println(node.getsname()+"的标签属性:"+node.label.get(i));
if(node.node.get(i).node.isEmpty()==true) {
System.out.println("叶子结点:"+node.node.get(i).getsname());
}
else {
put(node.node.get(i));
}
}
}
//用于对待决策数据进行预测并将结果保存在指定路径
public void testdate(ArrayList<ArrayList<String>> test,String path) throws IOException {
treeNode node=null;
int count=0;
node=creattree(this.date);//类
try {
BufferedWriter out=new BufferedWriter(new FileWriter(path));
for(int i=0;i<test.size();i++) {
testput(node,test.get(i));//递归调用
//System.out.println(kind);
for(int j=0;j<test.get(i).size();j++) {
out.write(test.get(i).get(j)+",");
}
if(kind.equals(date.get(i).get(date.get(i).size()-1))==true) {
count++;
}
out.write(kind);
out.newLine();
}
System.out.println("该次分类结果正确率为:"+(double)count/test.size()*100+"%");
out.flush();
out.close();
}catch(IOException e) {
e.printStackTrace();
}
}
//用于测试的递归调用
public void testput(treeNode node,ArrayList<String> t) {
int index=0;
for(int i=0;i<this.label.size();i++) {
if(this.label.get(i).equals(node.getsname())==true) {
index=i;
break;
}
}
for(int i=0;i<node.label.size();i++) {
if(t.get(index).equals(node.label.get(i))==false) {
continue;
}
if(node.node.get(i).node.isEmpty()==true) {
//System.out.println("分类结果为:"+node.node.get(i).getsname());
this.kind=node.node.get(i).getsname();//取出分类结果
}
else {
testput(node.node.get(i),t);
}
}
}
public static void main(String[] args) throws IOException {
String data=System.getProperty("user.dir") + "\\Data\\data1.txt";//训练数据集
String test=System.getProperty("user.dir") + "\\Data\\test.txt";//测试数据集
String result=System.getProperty("user.dir") + "\\Data\\result.txt";//预测结果集
ID3 id=new ID3(data,test);//初始化数据
id.print(id.date);//构建并输出决策树
//id.testdate(id.test,result);//预测数据并输出结果
}
}
本文深入解析ID3算法,一种用于构建决策树的贪心算法,探讨其在概念学习系统中的应用,以及如何通过信息增益选择最佳属性进行数据划分。文章详细介绍了ID3算法的优点、局限性,并提供了具体实现代码和示例。

被折叠的 条评论
为什么被折叠?



