- (转)http://leon-a.javaeye.com/blog/178585
- package graph;
- import java.util.ArrayList;
- import java.util.List;
- import java.util.TreeSet;
- /**
- * 决策树的ID3算法
- * 参照实现http://www.blog.edu.cn/user2/huangbo929/archives/2006/1533249.shtml
- * @author Leon.Chen
- */
- public class DTree {
- /**
- * root
- */
- TreeNode root;
- /**
- * 可见性数组
- */
- private static boolean[] visable;
- private Object[] array;
- private int index;
- /**
- * @param args
- */
- @SuppressWarnings("boxing")
- public static void main(String[] args) {
- //初始数据
- Object[] array = new Object[] {
- new String[]{ "Sunny" ,"Hot" ,"High" ,"Weak" ,"No" },
- new String[]{ "Sunny" ,"Hot" ,"High" ,"Strong" ,"No" },
- new String[]{ "Overcast" ,"Hot" ,"High" ,"Weak" ,"Yes"},
- new String[]{ "Rain" ,"Mild" ,"High" ,"Weak" ,"Yes"},
- new String[]{ "Rain" ,"Cool" ,"Normal" ,"Weak" ,"Yes"},
- new String[]{ "Rain" ,"Cool" ,"Normal" ,"Strong" ,"No" },
- new String[]{ "Overcast" ,"Cool" ,"Normal" ,"Strong" ,"Yes"},
- new String[]{ "Sunny" ,"Mild" ,"High" ,"Weak" ,"No" },
- new String[]{ "Sunny" ,"Cool" ,"Normal" ,"Weak" ,"Yes"},
- new String[]{ "Rain" ,"Mild" ,"Normal" ,"Weak" ,"Yes"},
- new String[]{ "Sunny" ,"Mild" ,"Normal" ,"Strong" ,"Yes"},
- new String[]{ "Overcast" ,"Mild" ,"High" ,"Strong" ,"Yes"},
- new String[]{ "Overcast" ,"Hot" ,"Normal" ,"Weak" ,"Yes"},
- new String[]{ "Rain" ,"Mild" ,"High" ,"Strong" ,"No" },
- };
- DTree tree = new DTree();
- tree.create(array,4);
- }
- public void create(Object[] array,int index){
- this.array = array;
- init(array,index);
- createDTree(array);
- printDTree(root);
- }
- public Object[] getMaxGain(Object[] array){
- Object[] result = new Object[2];
- double gain = 0;
- int index = 0;
- for(int i=0;i<visable.length;i++){
- if(!visable[i]){
- double value = gain(array,i);
- if(gain < value){
- gain = value;
- index = i;
- }
- }
- }
- result[0] = gain;
- result[1] = index;
- visable[index] = true;
- return result;
- }
- public void createDTree(Object[] array) {
- Object[] maxgain = getMaxGain(array);
- if (root == null) {
- root = new TreeNode();
- root.parent = null;
- root.parentArrtibute = null;
- root.arrtibutes = getArrtibutes(((Integer) maxgain[1]).intValue());
- root.nodeName = getNodeName(((Integer) maxgain[1]).intValue());
- root.childNodes = new TreeNode[root.arrtibutes.length];
- insertTree(array,root);
- }
- }
- public void insertTree(Object[] array,TreeNode parentNode){
- String[] arrtibutes = parentNode.arrtibutes;
- for(int i=0;i<arrtibutes.length;i++){
- Object[] pickArray = pickUpAndCreateArray(array,arrtibutes[i],getNodeIndex(parentNode.nodeName));
- Object[] info = getMaxGain(pickArray);
- double gain = ((Double)info[0]).doubleValue();
- if(gain != 0){
- int index = ((Integer) info[1]).intValue();
- System.out.println("gain = "+gain+" ,node name = "+getNodeName(index));
- TreeNode currentNode = new TreeNode();
- currentNode.parent = parentNode;
- currentNode.parentArrtibute = arrtibutes[i];
- currentNode.arrtibutes = getArrtibutes(index);
- currentNode.nodeName = getNodeName(index);
- currentNode.childNodes = new TreeNode[currentNode.arrtibutes.length];
- parentNode.childNodes[i] = currentNode;
- insertTree(pickArray,currentNode);
- }else {
- TreeNode leafNode = new TreeNode();
- leafNode.parent = parentNode;
- leafNode.parentArrtibute = arrtibutes[i];
- leafNode.arrtibutes = new String[0];
- leafNode.nodeName = getLeafNodeName(pickArray);
- leafNode.childNodes = new TreeNode[0];
- parentNode.childNodes[i] = leafNode;
- }
- }
- }
- public void printDTree(TreeNode node){
- System.out.println(node.nodeName);
- TreeNode[] childs = node.childNodes;
- for(int i=0;i<childs.length;i++){
- if(childs[i]!=null){
- System.out.println(childs[i].parentArrtibute);
- printDTree(childs[i]);
- }
- }
- }
- /**
- * @param dataArray 原始数组 D
- * @param criterion 标准值
- * @return double
- */
- public void init(Object[] dataArray,int index) {
- this.index = index;
- //数据初始化
- visable = new boolean[((String[])dataArray[0]).length];
- for(int i=0;i<visable.length;i++) {
- if(i == index){
- visable[i] = true;
- }else {
- visable[i] = false;
- }
- }
- }
- public Object[] pickUpAndCreateArray(Object[] array,String arrtibute,int index){
- List<String[]> list = new ArrayList<String[]>();
- for(int i=0;i<array.length;i++){
- String[] strs = (String[])array[i];
- if(strs[index].equals(arrtibute)){
- list.add(strs);
- }
- }
- return list.toArray();
- }
- /**
- * Entropy(S)
- * @param array
- * @return double
- */
- public double gain(Object[] array,int index) {
- String[] playBalls = getArrtibutes(this.index);
- int[] counts = new int[playBalls.length];
- for(int i=0;i<counts.length;i++) {
- counts[i] = 0;
- }
- for(int i=0;i<array.length;i++) {
- String[] strs = (String[])array[i];
- for(int j=0;j<playBalls.length;j++) {
- if(strs[this.index].equals(playBalls[j])) {
- counts[j]++;
- }
- }
- }
- /**
- * Entropy(S) = S -p(I) log2 p(I)
- */
- double entropyS = 0;
- for(int i=0;i<counts.length;i++) {
- entropyS += DTreeUtil.sigma(counts[i],array.length);
- }
- String[] arrtibutes = getArrtibutes(index);
- /**
- * total ((|Sv| / |S|) * Entropy(Sv))
- */
- double sv_total = 0;
- for(int i=0;i<arrtibutes.length;i++){
- sv_total += entropySv(array, index,arrtibutes[i],array.length);
- }
- return entropyS-sv_total;
- }
- /**
- * ((|Sv| / |S|) * Entropy(Sv))
- * @param array
- * @param index
- * @param arrtibute
- * @param allTotal
- * @return
- */
- public double entropySv(Object[] array,int index,String arrtibute,int allTotal) {
- String[] playBalls = getArrtibutes(this.index);
- int[] counts = new int[playBalls.length];
- for(int i=0;i<counts.length;i++) {
- counts[i] = 0;
- }
- for (int i = 0; i < array.length; i++) {
- String[] strs = (String[]) array[i];
- if (strs[index].equals(arrtibute)) {
- for (int k = 0; k < playBalls.length; k++) {
- if (strs[this.index].equals(playBalls[k])) {
- counts[k]++;
- }
- }
- }
- }
- int total = 0;
- double entropySv = 0;
- for(int i=0;i<counts.length;i++){
- total += counts[i];
- }
- for(int i=0;i<counts.length;i++){
- entropySv += DTreeUtil.sigma(counts[i],total);
- }
- return DTreeUtil.getPi(total, allTotal)*entropySv;
- }
- @SuppressWarnings("unchecked")
- public String[] getArrtibutes(int index) {
- TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
- for (int i = 0; i < array.length; i++) {
- String[] strs = (String[]) array[i];
- set.add(strs[index]);
- }
- String[] result = new String[set.size()];
- return set.toArray(result);
- }
- public String getNodeName(int index) {
- String[] strs = new String[]{"Outlook","Temperature","Humidity","Wind","Play ball"};
- for(int i=0;i<strs.length;i++){
- if(i == index){
- return strs[i];
- }
- }
- return null;
- }
- public String getLeafNodeName(Object[] array){
- if(array!=null && array.length>0){
- String[] strs = (String[])array[0];
- return strs[index];
- }
- return null;
- }
- public int getNodeIndex(String name) {
- String[] strs = new String[]{"Outlook","Temperature","Humidity","Wind","Play ball"};
- for(int i=0;i<strs.length;i++){
- if(name.equals(strs[i])){
- return i;
- }
- }
- return -1;
- }
- }
- package graph;
/**
 * A single node of the decision tree: either an inner node that splits on an
 * attribute or a leaf that carries a result.
 *
 * @author B.Chen
 */
public class TreeNode {

    /** Name of this node. */
    String nodeName;

    /** Attribute values of this node; one child slot exists per value. */
    String[] arrtibutes;

    /** Child nodes. */
    TreeNode[] childNodes;

    /** Parent node. */
    TreeNode parent;

    /** The attribute value of the parent that leads to this node. */
    String parentArrtibute;
}
- package graph;
/**
 * Numeric helpers for entropy calculations.
 */
public class DTreeUtil {

    /** Natural logarithm of 2, used to convert natural logs to base 2. */
    private static final double LN_2 = Math.log(2);

    /**
     * One term of an entropy sum: {@code -p * log2(p)} with
     * {@code p = x / total}.
     *
     * @param x number of occurrences of the value
     * @param total number of observations
     * @return the entropy term; {@code 0} when {@code x} or {@code total} is
     *         zero, so that empty classes and empty sets contribute nothing
     *         instead of producing NaN or an infinity
     */
    public static double sigma(int x, int total) {
        if (x == 0 || total == 0) {
            return 0;
        }
        double p = getPi(x, total);
        return -(p * logYBase2(p));
    }

    /**
     * Base-2 logarithm.
     *
     * @param y value to take the logarithm of
     * @return {@code log2(y)}
     */
    public static double logYBase2(double y) {
        return Math.log(y) / LN_2;
    }

    /**
     * Relative frequency of a value: occurrences divided by the total.
     *
     * @param x number of occurrences
     * @param total number of observations
     * @return {@code x / total} as a floating-point value, or {@code 0} when
     *         {@code total} is zero
     */
    public static double getPi(int x, int total) {
        if (total == 0) {
            return 0;
        }
        // The cast forces floating-point division.
        return (double) x / total;
    }
}
- package graph;
- import java.util.Comparator;
/**
 * Orders objects that are {@link String}s by their natural (lexicographic)
 * order.
 *
 * <p>Declared as {@code Comparator<Object>} rather than the raw type so it can
 * be handed to generic collections such as {@code TreeSet<String>} without an
 * unchecked warning, while {@link #compare(Object, Object)} keeps accepting
 * plain {@code Object} arguments.
 */
public class SequenceComparator implements Comparator<Object> {

    /**
     * Compares two strings with {@link String#compareTo(String)}.
     *
     * @param o1 first string
     * @param o2 second string
     * @return a negative value, zero, or a positive value as {@code o1} sorts
     *         before, equal to, or after {@code o2}
     * @throws ClassCastException if either argument is not a {@code String}
     */
    @Override
    public int compare(Object o1, Object o2) throws ClassCastException {
        String str1 = (String) o1;
        String str2 = (String) o2;
        return str1.compareTo(str2);
    }
}
决策树ID3算法
最新推荐文章于 2024-08-22 03:00:00 发布
1504

被折叠的 条评论
为什么被折叠?



