ID3算法本身就不介绍了。最终生成的决策树保存在XML文件中，XML的读写使用了Dom4J；注意，如果要让Dom4J支持按XPath选择节点，还需要额外引入jaxen.jar包。程序要求输入文件满足ARFF格式，并且所有属性都是标称（nominal）变量。
003 | import
java.io.BufferedReader;
|
005 | import
java.io.FileReader;
|
006 | import
java.io.FileWriter;
|
007 | import
java.io.IOException;
|
008 | import
java.util.ArrayList;
|
009 | import
java.util.Iterator;
|
010 | import
java.util.LinkedList;
|
011 | import
java.util.List; |
012 | import
java.util.regex.Matcher;
|
013 | import
java.util.regex.Pattern;
|
015 | import
org.dom4j.Document;
|
016 | import
org.dom4j.DocumentHelper;
|
017 | import
org.dom4j.Element;
|
018 | import
org.dom4j.io.OutputFormat;
|
019 | import
org.dom4j.io.XMLWriter;
|
022 |
private ArrayList<String> attribute =
new
ArrayList<String>(); |
023 |
private ArrayList<ArrayList<String>> attributevalue =
new
ArrayList<ArrayList<String>>(); |
024 |
private ArrayList<String[]> data =
new
ArrayList<String[]>();; |
026 |
public static
final
String patternString = "@attribute(.*)[{](.*?)[}]" ;
|
032 |
xmldoc = DocumentHelper.createDocument(); |
033 |
root = xmldoc.addElement( "root" );
|
034 |
root.addElement( "DecisionTree" ).addAttribute( "value" ,
"null" );
|
037 |
public static
void
main(String[] args) { |
038 |
ID3 inst = new
ID3(); |
039 |
inst.readARFF( new
File( "/home/orisun/test/weather.nominal.arff" ));
|
041 |
LinkedList<Integer> ll= new
LinkedList<Integer>();
|
042 |
for ( int
i= 0 ;i<inst.attribute.size();i++){
|
046 |
ArrayList<Integer> al= new
ArrayList<Integer>();
|
047 |
for ( int
i= 0 ;i<inst.data.size();i++){
|
050 |
inst.buildDT( "DecisionTree" ,
"null" , al, ll);
|
051 |
inst.writeXML( "/home/orisun/test/dt.xml" );
|
056 |
public void
readARFF(File file) {
|
058 |
FileReader fr = new
FileReader(file); |
059 |
BufferedReader br = new
BufferedReader(fr);
|
061 |
Pattern pattern = Pattern.compile(patternString);
|
062 |
while
((line = br.readLine()) != null ) {
|
063 |
Matcher matcher = pattern.matcher(line); |
064 |
if (matcher.find()) {
|
065 |
attribute.add(matcher.group( 1 ).trim());
|
066 |
String[] values = matcher.group( 2 ).split( "," );
|
067 |
ArrayList<String> al = new
ArrayList<String>(values.length);
|
068 |
for (String value : values) {
|
069 |
al.add(value.trim()); |
071 |
attributevalue.add(al); |
072 |
} else
if
(line.startsWith( "@data" )) {
|
073 |
while
((line = br.readLine()) != null ) {
|
076 |
String[] row = line.split( "," );
|
084 |
} catch
(IOException e1) { |
085 |
e1.printStackTrace(); |
090 |
public void
setDec( int
n) { |
091 |
if (n <
0 || n >= attribute.size()) {
|
092 |
System.err.println( "决策变量指定错误。" );
|
097 |
public void
setDec(String name) {
|
098 |
int n = attribute.indexOf(name);
|
103 |
public double
getEntropy( int [] arr) {
|
104 |
double entropy =
0.0 ; |
106 |
for ( int
i = 0 ; i < arr.length; i++) {
|
107 |
entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log( 2 );
|
110 |
entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log( 2 );
|
116 |
public double
getEntropy( int [] arr,
int sum) {
|
117 |
double entropy =
0.0 ; |
118 |
for ( int
i = 0 ; i < arr.length; i++) {
|
119 |
entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log( 2 );
|
121 |
entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log( 2 );
|
126 |
public boolean
infoPure(ArrayList<Integer> subset) {
|
127 |
String value = data.get(subset.get( 0 ))[decatt];
|
128 |
for ( int
i = 1 ; i < subset.size(); i++) {
|
129 |
String next=data.get(subset.get(i))[decatt];
|
131 |
if (!value.equals(next))
|
138 |
public double
calNodeEntropy(ArrayList<Integer> subset,
int index) {
|
139 |
int sum = subset.size();
|
140 |
double entropy =
0.0 ; |
141 |
int [][] info =
new int [attributevalue.get(index).size()][];
|
142 |
for ( int
i = 0 ; i < info.length; i++)
|
143 |
info[i] = new
int [attributevalue.get(decatt).size()];
|
144 |
int [] count = new
int [attributevalue.get(index).size()];
|
145 |
for ( int
i = 0 ; i < sum; i++) {
|
146 |
int n = subset.get(i);
|
147 |
String nodevalue = data.get(n)[index]; |
148 |
int nodeind = attributevalue.get(index).indexOf(nodevalue);
|
150 |
String decvalue = data.get(n)[decatt]; |
151 |
int decind = attributevalue.get(decatt).indexOf(decvalue);
|
152 |
info[nodeind][decind]++; |
154 |
for ( int
i = 0 ; i < info.length; i++) {
|
155 |
entropy += getEntropy(info[i]) * count[i] / sum;
|
161 |
public void
buildDT(String name, String value, ArrayList<Integer> subset,
|
162 |
LinkedList<Integer> selatt) { |
164 |
@SuppressWarnings ( "unchecked" )
|
165 |
List<Element> list = root.selectNodes( "//" +name);
|
166 |
Iterator<Element> iter=list.iterator(); |
167 |
while (iter.hasNext()){
|
169 |
if (ele.attributeValue( "value" ).equals(value))
|
172 |
if (infoPure(subset)) {
|
173 |
ele.setText(data.get(subset.get( 0 ))[decatt]);
|
177 |
double minEntropy = Double.MAX_VALUE;
|
178 |
for ( int
i = 0 ; i < selatt.size(); i++) {
|
181 |
double entropy = calNodeEntropy(subset, selatt.get(i));
|
182 |
if (entropy < minEntropy) {
|
183 |
minIndex = selatt.get(i); |
184 |
minEntropy = entropy; |
187 |
String nodeName = attribute.get(minIndex); |
188 |
selatt.remove( new
Integer(minIndex));
|
189 |
ArrayList<String> attvalues = attributevalue.get(minIndex);
|
190 |
for (String val : attvalues) {
|
191 |
ele.addElement(nodeName).addAttribute( "value" , val);
|
192 |
ArrayList<Integer> al = new
ArrayList<Integer>();
|
193 |
for
( int
i = 0 ; i < subset.size(); i++) {
|
194 |
if (data.get(subset.get(i))[minIndex].equals(val)) {
|
195 |
al.add(subset.get(i)); |
198 |
buildDT(nodeName, val, al, selatt); |
203 |
public void
writeXML(String filename) {
|
205 |
File file = new
File(filename); |
207 |
file.createNewFile(); |
208 |
FileWriter fw = new
FileWriter(file); |
209 |
OutputFormat format = OutputFormat.createPrettyPrint();
|
210 |
XMLWriter output = new
XMLWriter(fw, format);
|
211 |
output.write(xmldoc); |
213 |
} catch
(IOException e) { |
214 |
System.out.println(e.getMessage()); |
最终生成的XML文件如下：
<?xml version="1.0" encoding="UTF-8"?>
<DecisionTree value="null">
<humidity value="high">no</humidity>
<humidity value="normal">yes</humidity>
<outlook value="overcast">yes</outlook>
<windy value="TRUE">no</windy>
<windy value="FALSE">yes</windy>