FPGrowth的java实现

本文介绍了FPGrowth算法的Java实现,包括公共类`Aprioris`、条件模式基类`ConditionalPatternBase`和树节点类`TreeNode`。通过`getFrequentItemsets`方法获取频繁项集,使用`FPTree`构建频繁模式增长树,进而找出所有条件模式基,最终生成关联规则。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、公共类

package com.apriori.common;


import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;


/**
 * <p>本类描述: 公共类</p>
 * <p>其他说明: </p>
 * @author Wang Haiyang
 * @date 2015-6-23 下午01:42:01
 */
public class Aprioris {


    /**
     * 方法描述:得到频繁1项集
     * @param D:事务数据库
     * @param min_sup:最小支持度阀值
     * @return
     */
    public static List<ArrayList<Integer>> getFrequent1Itemsets(List<ArrayList<Integer>> D, Integer min_sup, Map<ArrayList<Integer>, Integer> L) {
        List<ArrayList<Integer>> results = new ArrayList<ArrayList<Integer>>();
        Map<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (ArrayList<Integer> d : D) {
            for (Integer g : d) {
                if (map.containsKey(g)) {
                    map.put(g, map.get(g) + 1);
                } else {
                    map.put(g, 1);
                }
            }
        }
        Set<Entry<Integer, Integer>> entrySet = map.entrySet();
        for (Entry<Integer, Integer> entry : entrySet) {
            if (entry.getValue() >= min_sup) {
                ArrayList<Integer> l = new ArrayList<Integer>();
                l.add(entry.getKey());
                results.add(l);
                L.put(l, entry.getValue());
            }
        }
        return results;
    }
    
    public static void displayAssociationRules(Map<String, Double> rules) {
        for (Entry<String, Double> entry : rules.entrySet()) {
            System.out.println(entry.getKey() + ":" + entry.getValue());
        }
    }
    
    /**
     * 方法描述:遍历频繁项集
     * @param L
     */
    public static void displayFrequentItemsets(Map<ArrayList<Integer>, Integer> L) {
        for (Entry<ArrayList<Integer>, Integer> entry : L.entrySet()) {
            System.out.print("(");
            for (Integer integer : entry.getKey()) {
                System.out.print(integer);
                System.out.print(",");
            }
            System.out.print(")");
            System.out.println();
        }
    }
    
    /**
     * 方法描述:产生关联规则
     * @param L
     * @param min_con
     * @return
     */
    public static Map<String, Double> produceAssociationRules(Map<ArrayList<Integer>, Integer> L, Double min_con) {
        Map<String, Double> result = new HashMap<String, Double>();
        for (Entry<ArrayList<Integer>, Integer> entry : L.entrySet()) {
            ArrayList<Integer> v = entry.getKey();
            if (v.size() > 1) {
                List<ArrayList<Integer>> lists = subList(v); // 得到给定list的所有非空真子集
                for (ArrayList<Integer> list : lists) {
                        List<Integer> exp = exceptList(v, list); // 得到除了list之外的子集
                        Integer integer1 = entry.getValue();
                        Integer integer2 = L.get(list);
                        if (integer1 != null && integer2 != null) {
                            Double per = Double.parseDouble(integer1 + "") / integer2;
                            if (per >= min_con) {
                                result.put(list.toString() + "=>" + exp.toString(), per);
                            }
                        }
                }
            }
        }
        return result;
    }
    
    /**
     * 方法描述:得到除了list之外的子集
     * @param key
     * @param list
     * @return
     */
    private static List<Integer> exceptList(ArrayList<Integer> key, ArrayList<Integer> list) {
        List<Integer> results = new ArrayList<Integer>();
        for (Integer l : key) {
            if (!list.contains(l)) {
                results.add(l);
            }
        }
        return results;
    }


    /**
     * 方法描述:得到给定list的所有非空真子集
     * @param key
     * @return
     */
    private static List<ArrayList<Integer>> subList(ArrayList<Integer> key) {
        List<ArrayList<Integer>> results = new ArrayList<ArrayList<Integer>>();
        for (int i = 0; i < key.size(); i++) {
            ArrayList<Integer> l = new ArrayList<Integer>();
            l.add(key.get(i));
            results.add(l);
        }
        
        for (int i = 0; i < key.size(); i++) {
            int keyi = key.get(i);
            for (int j = i + 1; j < key.size(); j++) {
                int keyj = key.get(j);
                ArrayList<Integer> l = new ArrayList<Integer>();
                l.add(keyi);
                l.add(keyj);
                Collections.sort(l);
                if (!l.containsAll(key)) {
                    if (!results.containsAll(l)) {
                        results.add(l);
                    }
                }
            }
        }
        return results;
    }
}


2、ConditionalPatternBase.java

package com.apriori.fpgrowth;


import java.util.ArrayList;
import java.util.List;


/**
 * <p>本类描述: t条件模式基</p>
 * <p>其他说明: </p>
 * @author Wang Haiyang
 * @date 2015-6-19 下午05:01:40
 */
public class ConditionalPatternBase {
    
    /**每个条件模式基*/
    private List<Integer> base = new ArrayList<Integer>();
   
    /**每个条件模式基的值*/
    private Integer value;


    public List<Integer> getBase() {
        return base;
    }


    public void setBase(List<Integer> base) {
        this.base = base;
    }


    public Integer getValue() {
        return value;
    }


    public void setValue(Integer value) {
        this.value = value;
    }
}


3、TreeNode

package com.apriori.fpgrowth;


import java.util.ArrayList;
import java.util.List;


public class TreeNode implements Comparable<TreeNode>{


    /**节点名字*/
    private Integer name;
    
    /**节点的出现次数*/
    private Integer value = 0;
    
    /**节点的孩子*/
    private List<TreeNode> child = new ArrayList<TreeNode>();
    
    /**节点的父亲*/
    private TreeNode parent;
    
    @Override
    public int compareTo(TreeNode o) {
        return o.getValue() - this.value;
    }
    
    public List<TreeNode> getChild() {
        return child;
    }


    public void setChild(List<TreeNode> child) {
        this.child = child;
    }


    public TreeNode getParent() {
        return parent;
    }


    public void setParent(TreeNode parent) {
        this.parent = parent;
    }


    public Integer getName() {
        return name;
    }


    public void setName(Integer name) {
        this.name = name;
    }


    public Integer getValue() {
        return value;
    }


    public void setValue(Integer value) {
        this.value = value;
    }


}


4、FPGrowth


package com.apriori.fpgrowth;


import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;


import com.apriori.common.Aprioris;


/**
 * <p>
 *     本类描述:
 *         本类主要完成找出频繁项集
 *  </p>
 * <p>
 *     主要步骤:
 *         1. 找出频繁1项集,并按照支持度递减排列
 *         2. 遍历项集D,按照频繁1项集的顺序构造频繁模式增长树
 *         3. 找出条件模式基
 *         4. 根据条件模式基找出频繁模式
 * </p>
 * @author Wang Haiyang
 * @date 2015-6-19 上午10:26:53
 */
public class FPGrowth {


    /**
     * 方法描述: 得到频繁模式集
     * @param D
     * @param min_sup
     * @return
     */
    public static Map<ArrayList<Integer>, Integer> getFrequentItemsets(List<ArrayList<Integer>> D, Integer min_sup) {
        Map<ArrayList<Integer>, Integer> L = new HashMap<ArrayList<Integer>, Integer>();
        Aprioris.getFrequent1Itemsets(D, min_sup, L); // 得到频繁1项集
        ArrayList<TreeNode> L1 = sortDes(L); // 降序排序频繁1项集
        TreeNode root = createFPTree(D, L1); // 得到频繁模式增长树
        getFrequentItemsetsByFPTree(D, min_sup, root, L, L1); // 得到频繁模式
        return L;
    }


    /**
     * 方法描述:得到频繁模式
     * @param D
     * @param min_sup
     * @param root
     * @param L
     */
    private static void getFrequentItemsetsByFPTree(List<ArrayList<Integer>> D, Integer min_sup, TreeNode root,
            Map<ArrayList<Integer>, Integer> L, ArrayList<TreeNode> L1) {
        List<TreeNode> nodes = getLeafs(root); // 得到树的叶子节点
        Map<Integer, ArrayList<ConditionalPatternBase>> map = getAllConditionalPatternBases(nodes); // 得到所有的条件模式基
        // 得到所有的频繁模式
       
        for (Entry<Integer, ArrayList<ConditionalPatternBase>> entry : map.entrySet()) { // 得到组合条件模式基
            ArrayList<ConditionalPatternBase> value = entry.getValue();
            TreeNode t = createFPTree(value, L1); // 得到组合条件模式基树
            List<TreeNode> n = new ArrayList<TreeNode>();
            getNodes(t, min_sup, n); // 得到满足min_sup的节点
            
            for (int i = 0; i < n.size(); i++) {
                ArrayList<Integer> l = new ArrayList<Integer>();
                l.add(n.get(i).getName());
                l.add(entry.getKey());
                L.put(l, n.get(i).getValue());
            }
            
            for (int i = 0; i < n.size(); i++) {
                int keyi = n.get(i).getName();
                for (int j = i + 1; j < n.size(); j++) {
                    int keyj = n.get(j).getName();
                    ArrayList<Integer> l = new ArrayList<Integer>();
                    l.add(keyi);
                    l.add(keyj);
                    l.add(entry.getKey());
                    L.put(l, n.get(j).getValue());
                }
            }
        }
        
        
    }


    /**
     * 方法描述:得到满足min_sup的节点
     * @param node
     * @param min_sup
     * @param results
     */
    private static void getNodes(TreeNode node, Integer min_sup, List<TreeNode> results) {
        List<TreeNode> childs = node.getChild();
        if (childs == null || childs.size() == 0) {
            return;
        } else {
            for (TreeNode child : childs) {
                if (child.getValue() >= min_sup) {
                    results.add(child);
                }
                getNodes(child, min_sup, results);
            }
        }
        return;
    }


    /**
     * 方法描述:得到所有的条件模式基
     * @param nodes
     * @return
     */
    private static Map<Integer, ArrayList<ConditionalPatternBase>> getAllConditionalPatternBases(List<TreeNode> nodes) {
        Map<Integer, ArrayList<ConditionalPatternBase>> results = new HashMap<Integer, ArrayList<ConditionalPatternBase>>();
        for (TreeNode leaf : nodes) {
            ConditionalPatternBase base = new ConditionalPatternBase();
            TreeNode parent = leaf.getParent();
            base.setValue(leaf.getValue());
            List<Integer> ins = new ArrayList<Integer>();
            while (parent != null && parent.getName() != null) {
                ins.add(parent.getName());
                parent = parent.getParent();
            }
            Collections.reverse(ins);
            base.setBase(ins);
            if (results.containsKey(leaf.getName())) {
                results.get(leaf.getName()).add(base);
            } else {
                ArrayList<ConditionalPatternBase> lists = new ArrayList<ConditionalPatternBase>();
                lists.add(base);
                results.put(leaf.getName(), lists);
            }
        }
        return results;
    }


    /**
     * 方法描述:得到指定树的所有叶子节点
     * @param root
     * @return
     */
    private static List<TreeNode> getLeafs(TreeNode root) {
        List<TreeNode> results = new ArrayList<TreeNode>();
        traverseTree(root, results);
        return results;
    }


    /**
     * 方法描述:递归遍整个树
     * @param node
     */
    private static void traverseTree(TreeNode node, List<TreeNode> results) {
        List<TreeNode> childs = node.getChild();
        if (childs == null || childs.size() == 0) {
            results.add(node);
        } else {
            for (TreeNode child : childs) {
                traverseTree(child, results);
            }
        }
    }


    /**
     * 方法描述: 得到频繁模式增长树
     * @param D
     * @param L1
     * @return
     */
    private static TreeNode createFPTree(List<ArrayList<Integer>> D, ArrayList<TreeNode> L1) {
        TreeNode root = new TreeNode();
        for (ArrayList<Integer> lists : D) {
            int flag = 0;
            for (TreeNode node : L1) { // 针对lists,按照L1的顺序排序
                if(lists.contains(node.getName())) {
                    int index = lists.indexOf(node.getName());
                    swap(lists, index, flag);
                    flag++;
                }
            }
            
            TreeNode node = root;
            for (Integer element : lists) { // 将lists放到result(即tree中)
                if(containsValue(node.getChild(), element)) {
                    int index = getIndexOf(node.getChild(), element);
                    node.getChild().get(index).setValue(node.getChild().get(index).getValue() + 1);
                    node.getChild().get(index).setParent(node);
                    node = node.getChild().get(index);
                } else {
                    TreeNode n = new TreeNode();
                    n.setName(element);
                    n.setValue(1);
                    node.getChild().add(n);
                    n.setParent(node);
                    node = n;
                }
            }
        }
        return root;
    }
    
    private static TreeNode createFPTree(ArrayList<ConditionalPatternBase> value, ArrayList<TreeNode> L1) {
        TreeNode root = new TreeNode();
        for (ConditionalPatternBase c : value) {
           ArrayList<Integer> lists = (ArrayList<Integer>)c.getBase();
           int v = c.getValue();
            int flag = 0;
            for (TreeNode node : L1) { // 针对lists,按照L1的顺序排序
                if(lists.contains(node.getName())) {
                    int index = lists.indexOf(node.getName());
                    swap(lists, index, flag);
                    flag++;
                }
            }
            
            TreeNode node = root;
            for (Integer element : lists) { // 将lists放到result(即tree中)
                if(containsValue(node.getChild(), element)) {
                    int index = getIndexOf(node.getChild(), element);
                    node.getChild().get(index).setValue(node.getChild().get(index).getValue() + v);
                    node.getChild().get(index).setParent(node);
                    node = node.getChild().get(index);
                } else {
                    TreeNode n = new TreeNode();
                    n.setName(element);
                    n.setValue(v);
                    node.getChild().add(n);
                    n.setParent(node);
                    node = n;
                }
            }
        }
        return root;
    }
    
    /**
     * 方法描述: 交换
     * @param lists
     * @param index
     * @param flag
     */
    private static void swap(ArrayList<Integer> lists, int index, int flag) {
        int temp = lists.get(index);
        lists.set(index, lists.get(flag));
        lists.set(flag, temp);
    }


    /**
     * 方法描述:按照出现次数降序排序频繁1项集
     * @param L
     * @return
     */
    private static ArrayList<TreeNode> sortDes(Map<ArrayList<Integer>, Integer> L) {
        ArrayList<TreeNode> results = new ArrayList<TreeNode>();
        for (Entry<ArrayList<Integer>, Integer> enttry : L.entrySet()) {
            TreeNode node = new TreeNode();
            node.setName(enttry.getKey().get(0));
            node.setValue(enttry.getValue());
            results.add(node);
        }
        Collections.sort(results);
        return results;
    }
    
    private static int getIndexOf(List<TreeNode> child, Integer element) {
        for (int i = 0; i < child.size(); i++) {
            if(child.get(i).getName() == element) {
                return i;
            }
        }
        return 0;
    }


    private static boolean containsValue(List<TreeNode> child, Integer element) {
        for (TreeNode node : child) {
            if(node.getName() == element) {
                return true;
            }
        }
        return false;
    }
    
    public static void main(String[] args) {
        List<ArrayList<Integer>> D = new ArrayList<ArrayList<Integer>>();
        ArrayList<Integer> list1 = new ArrayList<Integer>();
        list1.add(1);
        list1.add(2);
        list1.add(5);
        ArrayList<Integer> list2 = new ArrayList<Integer>();
        list2.add(2);
        list2.add(4);
        ArrayList<Integer> list3 = new ArrayList<Integer>();
        list3.add(2);
        list3.add(3);
        ArrayList<Integer> list4 = new ArrayList<Integer>();
        list4.add(1);
        list4.add(2);
        list4.add(4);
        ArrayList<Integer> list5 = new ArrayList<Integer>();
        list5.add(1);
        list5.add(3);
        ArrayList<Integer> list6 = new ArrayList<Integer>();
        list6.add(2);
        list6.add(3);
        ArrayList<Integer> list7 = new ArrayList<Integer>();
        list7.add(1);
        list7.add(3);
        ArrayList<Integer> list8 = new ArrayList<Integer>();
        list8.add(1);
        list8.add(2);
        list8.add(3);
        list8.add(5);
        ArrayList<Integer> list9 = new ArrayList<Integer>();
        list9.add(1);
        list9.add(2);
        list9.add(3);
        D.add(list1);
        D.add(list2);
        D.add(list3);
        D.add(list4);
        D.add(list5);
        D.add(list6);
        D.add(list7);
        D.add(list8);
        D.add(list9);
        Integer min_sup = 2;
        Double min_con = 0.7;
        Map<ArrayList<Integer>, Integer> L = getFrequentItemsets(D, min_sup);
        Aprioris.displayFrequentItemsets(L); // 打印频繁项集
        Map<String, Double> rules = Aprioris.produceAssociationRules(L, min_con); // 产生关联规则
        Aprioris.displayAssociationRules(rules); // 打印关联规则
    }
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值