2013年11月19日注:以下算法中,combine算法实现不正确,应该是从已有的频繁中来产生。需要进一步修改
=================================================================================
Apriori算法原理:
如果某个项集是频繁的,那么它所有的子集也是频繁的。如果一个项集是非频繁的,那么它所有的超集也是非频繁的。
示意图
图一:
图二:
package cn.ffr.frequent.apriori;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Apriori的核心代码实现
* @author neu_fufengrui@163.com
*/
public class Apriori {
public static final String STRING_SPLIT = ",";
/**
* 主要的计算方法
* @param data 数据集
* @param minSupport 最小支持度
* @param maxLoop 最大执行次数,设NULL为获取最终结果
* @param containSet 结果中必须包含的子集
* @return
*/
public Map<String, Double> compute(List<String[]> data, Double minSupport, Integer maxLoop, String[] containSet){
//校验
if(data == null || data.size() <= 0){
return null;
}
//初始化
Map<String, Double> result = new HashMap<String, Double>();
Object[] itemSet = getDataUnitSet(data);
int loop = 0;
//核心循环处理过程
while(true){
//重要步骤一:合并,产生新的频繁集
Set<String> keys = combine(result.keySet(), itemSet);
result.clear();//移除之前的结果
for(String key : keys){
result.put(key, computeSupport(data, key.split(STRING_SPLIT)));
}
//重要步骤二:修剪,去除支持度小于条件的。
cut(result, minSupport, containSet);
loop++;
//输出计算过程
System.out.println("loop ["+loop+"], result : "+result);
//循环结束条件
if(result.size() <= 0){
break;
}
if(maxLoop != null && maxLoop > 0 && loop >= maxLoop){//可控制循环执行次数
break;
}
}
return result;
}
/**
* 计算子集的支持度
*
* 支持度 = 子集在数据集中的数据项 / 总的数据集的数据项
*
* 数据项的意思是一条数据。
* @param data 数据集
* @param subSet 子集
* @return
*/
public Double computeSupport(List<String[]> data, String[] subSet){
Integer value = 0;
for(int i = 0; i < data.size(); i++){
if(contain(data.get(i), subSet)){
value ++;
}
}
return value*1.0/data.size();
}
/**
* 获得初始化唯一的数据集,用于初始化
* @param data
* @return
*/
public Object[] getDataUnitSet(List<String[]> data){
List<String> uniqueKeys = new ArrayList<String>();
for(String[] dat : data){
for(String da : dat){
if(!uniqueKeys.contains(da)){
uniqueKeys.add(da);
}
}
}
return uniqueKeys.toArray();
}
/**
* 合并src和target来获取频繁集
* 增加频繁集的计算维度
* @param src
* @param target
* @return
*/
public Set<String> combine(Set<String> src, Object[] target){
Set<String> dest = new HashSet<String>();
if(src == null || src.size() <= 0){
for(Object t : target){
dest.add(t.toString());
}
return dest;
}
for(String s : src){
for(Object t : target){
if(s.indexOf(t.toString())<0){
String key = s+STRING_SPLIT+t;
if(!contain(dest, key)){
dest.add(key);
}
}
}
}
return dest;
}
/**
* dest集中是否包含了key
* @param dest
* @param key
* @return
*/
public boolean contain(Set<String> dest, String key){
for(String d : dest){
if(equal(d.split(STRING_SPLIT), key.split(STRING_SPLIT))){
return true;
}
}
return false;
}
/**
* 移除结果中,支持度小于所需要的支持度的结果。
* @param result
* @param minSupport
* @return
*/
public Map<String, Double> cut(Map<String, Double> result, Double minSupport, String[] containSet){
for(Object key : result.keySet().toArray()){//防止 java.util.ConcurrentModificationException,使用keySet().toArray()
if(minSupport != null && minSupport > 0 && minSupport < 1 && result.get(key) < minSupport){//比较支持度
result.remove(key);
}
if(containSet != null && containSet.length > 0 && !contain(key.toString().split(STRING_SPLIT), containSet)){
result.remove(key);
}
}
return result;
}
/**
* src中是否包含dest,需要循环遍历查询
* @param src
* @param dest
* @return
*/
public static boolean contain(String[] src, String[] dest){
for(int i = 0; i < dest.length; i++){
int j = 0;
for(; j < src.length; j++){
if(src[j].equals(dest[i])){
break;
}
}
if(j == src.length){
return false;//can not find
}
}
return true;
}
/**
* src是否与dest相等
* @param src
* @param dest
* @return
*/
public boolean equal(String[] src, String[] dest){
if(src.length == dest.length && contain(src, dest)){
return true;
}
return false;
}
/**
* 主测试方法
* 测试方法,挨个去掉注释,进行测试。
*/
public static void main(String[] args) throws Exception{
//test 1
// List<String[]> data = loadSmallData();
// Long start = System.currentTimeMillis();
// Map<String, Double> result = new Apriori().compute(data, 0.5, 3, null);//求支持度大于指定值
// Long end = System.currentTimeMillis();
// System.out.println("Apriori Result [costs:"+(end-start)+"ms]: ");
// for(String key : result.keySet()){
// System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];");
// }
//test 2
// List<String[]> data = loadMushRoomData();
// Long start = System.currentTimeMillis();
// Map<String, Double> result = new Apriori().compute(data, 0.3, 4, new String[]{"2"});//求支持度大于指定值
// Long end = System.currentTimeMillis();
// System.out.println("Apriori Result [costs:"+(end-start)+"ms]: ");
// for(String key : result.keySet()){
// System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];");
// }
//test 3
List<String[]> data = loadChessData();
Long start = System.currentTimeMillis();
Map<String, Double> result = new Apriori().compute(data, 0.95, 3, null);//求支持度大于指定值
Long end = System.currentTimeMillis();
System.out.println("Apriori Result [costs:"+(end-start)+"ms]: ");
for(String key : result.keySet()){
System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];");
}
}
/*
* SmallData: minSupport 0.5, maxLoop 3, containSet null, [costs: 16ms]
* MushRoomData: minSupport 0.3, maxLoop 4, containSet {"2"}, [costs: 103250ms]
* ChessData: minSupport 0.95, maxLoop 34, containSet {null, [costs: 9718ms]
*/
//测试数据集-1
public static List<String[]> loadSmallData() throws Exception{
List<String[]> data = new ArrayList<String[]>();
data.add(new String[]{"d1","d3","d4"});
data.add(new String[]{"d2","d3","d5"});
data.add(new String[]{"d1","d2","d3","d5"});
data.add(new String[]{"d2","d5"});
return data;
}
//测试数据集-2
public static List<String[]> loadMushRoomData() throws Exception{
String link = "http://fimi.ua.ac.be/data/mushroom.dat";
URL url = new URL(link);
BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream()));
String temp = reader.readLine();
List<String[]> result = new ArrayList<String[]>();
int lineNumber = 0;
while(temp != null){
System.out.println("reading data... [No."+(++lineNumber)+"]");
String[] item = temp.split(" ");
result.add(item);
temp = reader.readLine();
}
reader.close();
return result;
}
//测试数据集-3
public static List<String[]> loadChessData() throws Exception{
String link = "http://fimi.ua.ac.be/data/chess.dat";
URL url = new URL(link);
BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream()));
String temp = reader.readLine();
List<String[]> result = new ArrayList<String[]>();
int lineNumber = 0;
while(temp != null){
System.out.println("reading data... [No."+(++lineNumber)+"]");
String[] item = temp.split(" ");
result.add(item);
temp = reader.readLine();
}
reader.close();
return result;
}
}
算法原理: