1. 算法原理
FP-Tree（FP-Growth）算法相对于Apriori算法减少了扫描数据集（I/O）的次数。其原理是：先扫描原始数据，求出频繁1项集，即项头表；再按项头表中各项的支持度（sup）从高到低对每条原始事务中的项排序，并删去不频繁的项。然后创建树形结构，每个节点存储项名称和出现次数，将排序后的事务依次插入树中，建树过程即完成。挖掘过程是倒序遍历项头表：对于每个项s，找出树中s到根节点的路径，并合并其余分支上的s，路径上各父节点的sup值取为其下所有s节点的sup值之和，由此得到频繁项集。最终从中求出最大频繁项集即可。
2. 代码实现
package com.clxk1997;
/**
* @Description 单个数据节点 name->cnt
* @Author Clxk
* @Date 2019/4/15 20:53
* @Version 1.0
*/
public class Data implements Comparable{
private String name;
private int cnt;
public Data() {
}
public Data(String name, int cnt) {
this.name = name;
this.cnt = cnt;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public int getCnt() {
return cnt;
}
public void setCnt(int cnt) {
this.cnt = cnt;
}
@Override
public int compareTo(Object o) {
if(this.cnt > ((Data)o).getCnt()) return 1;
return 0;
}
}
package com.clxk1997;
import java.util.ArrayList;
import java.util.List;
/**
 * A node of the FP-Tree. Each node stores one item ({@link Data}: item name and count),
 * a reference to its parent and the list of its children. The root node carries
 * {@code null} data and a {@code null} parent.
 *
 * @Author Clxk
 * @Date 2019/4/15 20:52
 * @Version 1.0
 */
public class Node {
    /** Item name and count stored in this node; {@code null} for the root. */
    Data data;
    /** Child nodes, in insertion order. */
    ArrayList<Node> child = new ArrayList<>();
    /** Parent node; {@code null} for the root. */
    Node parent;

    public Data getData() {
        return data;
    }

    public void setData(Data data) {
        this.data = data;
    }

    public ArrayList<Node> getChild() {
        return child;
    }

    public void setChild(ArrayList<Node> child) {
        this.child = child;
    }

    public Node getParent() {
        return parent;
    }

    public void setParent(Node parent) {
        this.parent = parent;
    }

    /**
     * Creates an empty root node (no parent, no data, no children).
     *
     * @return a new root node
     */
    public static Node init() {
        Node node = new Node();
        node.setParent(null);
        node.setData(null);
        node.setChild(new ArrayList<>());
        return node;
    }

    /**
     * Inserts one transaction into the tree rooted at {@code root}.
     *
     * <p>Walking down from the root, each item either reuses the child with the same name
     * (incrementing its count) or creates a new child with count 1. Every node visited along
     * the path is appended to {@code Main.leaf}.
     *
     * @param list the items of one transaction, presumably already in header-table order;
     *     {@code null} is treated as empty
     * @param root the tree root
     * @return {@code root}
     */
    public static Node putList2Tree(ArrayList<String> list, Node root) {
        if (list == null) {
            return root;
        }
        Node parent = root;
        for (String item : list) {
            Node curNode = findChild(parent, item);
            if (curNode == null) {
                // No existing branch for this item: start a new one.
                curNode = new Node();
                curNode.setData(new Data(item, 1));
                curNode.setParent(parent);
                curNode.setChild(new ArrayList<>());
                parent.getChild().add(curNode);
            } else {
                curNode.getData().setCnt(curNode.getData().getCnt() + 1);
            }
            parent = curNode;
            Main.leaf.add(parent);
        }
        return root;
    }

    /** Returns the first child of {@code parent} whose item name equals {@code name}, or null. */
    private static Node findChild(Node parent, String name) {
        for (Node c : parent.getChild()) {
            if (c.getData().getName().equals(name)) {
                return c;
            }
        }
        return null;
    }

    /**
     * Prints every node below {@code root} in depth-first pre-order, one
     * {@code "name cnt"} line per node. The root itself is not printed.
     *
     * @param root the node whose descendants are printed
     */
    public static void dfs(Node root) {
        for (Node c : root.getChild()) {
            System.out.println(c.getData().getName() + " " + c.getData().getCnt());
            dfs(c);
        }
    }

    /**
     * Sums the counts of the nodes named {@code data.getName()} in the subtree rooted at
     * {@code node}. The search stops descending at the first matching node on each path,
     * so a match's own descendants are not added again.
     *
     * @param node subtree root; {@code null} yields 0, and a root node with {@code null}
     *     data is searched through its children
     * @param data the item to look for (only its name is used)
     * @return the summed count of matching nodes
     */
    public static int getAllChildCount(Node node, Data data) {
        if (node == null) {
            return 0;
        }
        Data own = node.getData();
        if (own != null && own.getName().equals(data.getName())) {
            return own.getCnt();
        }
        int total = 0;
        for (Node n : node.getChild()) {
            total += getAllChildCount(n, data);
        }
        return total;
    }

    /**
     * Returns the depth of {@code node}, counting the root as depth 1.
     *
     * @param node a tree node
     * @return number of nodes on the path from the root to {@code node}, inclusive
     */
    public static int getDepth(Node node) {
        if (node.getParent() == null) {
            return 1;
        }
        return getDepth(node.getParent()) + 1;
    }
}
package com.clxk1997;
import java.util.*;
/**
 * FP-Tree demo entry point. It reads transactions from standard input, builds the header
 * table and the FP-Tree, and prints the frequent itemset found for each header item
 * together with the overall best one.
 */
public class Main {
    /** Maximum number of transactions that can be read. */
    private static final int MAXN = 3000;
    /** Raw transactions; row i holds the items of transaction i, unused rows stay null. */
    @SuppressWarnings("unchecked")
    private static ArrayList<String>[] data = new ArrayList[MAXN];
    /** Header table: each item with its support count. */
    private static ArrayList<Data> list = new ArrayList<>();
    /** Root of the FP-Tree. */
    private static Node root;
    /** Number of transactions and minimum support count. */
    private static int datacnt;
    private static int minsupport;
    /** Every tree node visited during insertion (filled by Node.putList2Tree). */
    public static List<Node> leaf = new ArrayList<>();

    /**
     * Reads the dataset size, the minimum support and then one whitespace-separated
     * transaction per line, and runs {@link #solve()}.
     *
     * @throws IllegalArgumentException if the dataset size is negative or exceeds MAXN
     */
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        System.out.println("请输入数据集大小: ");
        datacnt = scanner.nextInt();
        if (datacnt < 0 || datacnt > MAXN) {
            throw new IllegalArgumentException(
                    "dataset size must be between 0 and " + MAXN + ": " + datacnt);
        }
        System.out.println("请输入最小支持度: ");
        minsupport = scanner.nextInt();
        System.out.println("请输入原始数据集: ");
        // Consume the remainder of the line left behind by nextInt().
        scanner.nextLine();
        for (int i = 0; i < datacnt; i++) {
            data[i] = new ArrayList<>();
            // Split on runs of whitespace so repeated or leading spaces don't yield empty items.
            for (String item : scanner.nextLine().trim().split("\\s+")) {
                if (!item.isEmpty()) {
                    data[i].add(item);
                }
            }
        }
        solve();
    }

    /**
     * Runs the whole pipeline, in order:
     * header table, sorted transactions, tree construction, tree dump, then mining.
     */
    public static void solve() {
        for (ArrayList<String> it : data) {
            if (it == null) {
                break;
            }
            for (String s : it) {
                putIntoList(s);
            }
        }
        // Sort and print the header table items that satisfy the minimum support.
        sortAndOut();
        // Reorder every transaction by the header table and print it.
        getSortedList();
        // Insert the sorted transactions into the tree.
        root = Node.init();
        for (ArrayList<String> row : data) {
            if (row == null) {
                break;
            }
            Node.putList2Tree(row, root);
        }
        System.out.println("深搜遍历树: ");
        Node.dfs(root);
        // Walk the header table from the least to the most frequent item.
        ArrayList<Data> ansdata = new ArrayList<>();
        int maxd = 0;
        int maxlen = 0;
        for (int i = list.size() - 1; i >= 0; i--) {
            Data head = list.get(i);
            Node node = getDepthNode(head);
            ArrayList<Data> curdata = new ArrayList<>();
            searchFrequence(node, head, curdata);
            System.out.println("项头: " + head.getName() + "的最大频繁项集是: ");
            int curd = outData(curdata);
            // Prefer longer itemsets; break ties by the larger total count.
            if (curdata.size() > maxlen || (curdata.size() == maxlen && curd > maxd)) {
                maxlen = curdata.size();
                maxd = curd;
                ansdata = new ArrayList<>(curdata);
            }
        }
        System.out.println("所以最终频繁项集为: ");
        outData(ansdata);
    }

    /**
     * Returns the deepest node in {@link #leaf} whose item name equals {@code data}'s name.
     * Ties keep the first such node.
     *
     * @param data the header item to look for
     * @return the deepest matching node, or null if none exists
     */
    public static Node getDepthNode(Data data) {
        Node node = null;
        int depth = 0;
        for (Node n : leaf) {
            if (n.getData().getName().equals(data.getName())) {
                int cnt = Node.getDepth(n);
                if (cnt > depth) {
                    depth = cnt;
                    node = n;
                }
            }
        }
        return node;
    }

    /**
     * Walks from {@code node} up to (but excluding) the root. For each node on the path it
     * appends an entry to {@code curdata} holding that node's name and the summed count of
     * {@code data} nodes in its subtree.
     *
     * @param node starting node; {@code null} or the root adds nothing
     * @param data the header item being mined
     * @param curdata output list, filled from the starting node upward
     */
    public static void searchFrequence(Node node, Data data, ArrayList<Data> curdata) {
        if (node == null || node.getData() == null) {
            return;
        }
        Data data1 = new Data();
        data1.setName(node.getData().getName());
        data1.setCnt(Node.getAllChildCount(node, data));
        curdata.add(data1);
        searchFrequence(node.getParent(), data, curdata);
    }

    /**
     * Rewrites every transaction so it contains only header-table items, in header-table
     * order, and prints the result.
     */
    public static void getSortedList() {
        @SuppressWarnings("unchecked")
        ArrayList<String>[] cur = new ArrayList[MAXN];
        for (int i = 0; i < datacnt; i++) {
            cur[i] = new ArrayList<>();
            ArrayList<String> str = data[i];
            for (Data head : list) {
                if (str.contains(head.getName())) {
                    cur[i].add(head.getName());
                }
            }
        }
        data = cur;
        System.out.println("排序后的数据集: ");
        for (ArrayList<String> row : data) {
            if (row == null) {
                break;
            }
            System.out.println(Arrays.toString(row.toArray()));
        }
    }

    /**
     * Sorts the header table by descending count, drops items below the minimum support,
     * and prints what remains.
     */
    public static void sortAndOut() {
        // List.sort is stable, so items with equal counts keep their first-seen order.
        list.sort((o1, o2) -> Integer.compare(o2.getCnt(), o1.getCnt()));
        list.removeIf(d -> d.getCnt() < minsupport);
        System.out.println("满足支持度的项头表: ");
        for (Data d : list) {
            System.out.println(d.getName() + " " + d.getCnt());
        }
    }

    /**
     * Adds one occurrence of item {@code s} to the header table. It increments an existing
     * entry or appends a new one with count 1.
     *
     * @param s the item name
     */
    public static void putIntoList(String s) {
        for (Data d : list) {
            if (d.getName().equals(s)) {
                d.setCnt(d.getCnt() + 1);
                return;
            }
        }
        list.add(new Data(s, 1));
    }

    /**
     * Prints {@code curdata} as {@code [name,cnt name,cnt ...]} on one line.
     *
     * @param curdata entries to print
     * @return the sum of all printed counts
     */
    public static int outData(ArrayList<Data> curdata) {
        System.out.print("[");
        int curd = 0;
        for (int j = 0; j < curdata.size(); j++) {
            if (j != 0) {
                System.out.print(" ");
            }
            System.out.print(curdata.get(j).getName() + "," + curdata.get(j).getCnt());
            curd += curdata.get(j).getCnt();
        }
        System.out.println("]");
        return curd;
    }
}