FPGrowth算法的Python实现

本文用Python实现了FPGrowth算法,并用kaggle上的超市数据集进行了测试。

项头表的定义

class head_node(object):
    def __init__(self, head_name, num, soft_link):
        self.head_name = head_name
        self.num = num
        self.soft_link = soft_link
        
    def show(self):
        """Show the description of the present node"""
        pass

对每一个item的名称,出现的次数,以及对应的第一个FP树中的节点链接进行了存储。

FP树节点的定义

class FP_node(object):
    def __init__(self, node_name, val, parent, soft_link):
        self.node_name = node_name  # 当前节点的名称
        self.val = val              # 当前节点的次数
        self.parent = parent        # 当前节点的双亲结点
        self.children = {}          # 当前节点的孩子节点们,用字典存储,key为str,value为FP_node
        self.soft_link = soft_link  # 存储项头表指针的下一个节点
        
    def show(self):
        """show the description of the present FP_node"""
        pass

以上两个节点定义中的show方法并没有写,可以自行定义

主体方法
仿照sklearn接口的设计方式,来实现FPGrowth算法。

class FPGrowth(object):
    def __init__(self, ThresHold):
        self.ThresHold = ThresHold  # 设定阈值,对item进行过滤
        self.mapping = {}
        self.head_table = []  # 项头表
        self.root_node = FP_node("root", None, None, None)  # FP树的根节点
        
    def fit(self, dataset):
        """
        仿照sklearn写的fit接口
        :param dataset: 输入数据 
        """
        self.scan_data(dataset) # 扫描数据
        self.head_table_mapping = {}
        for name, num in self.mapping.items():
            node = head_node(name, num, None)
            self.head_table_mapping[name] = node
            self.head_table.append(node)
        
        # 对head_table 根据node值进行降序排序
        self.head_table = sorted(self.head_table, key = lambda x : x.num, reverse=True)
        
        # 对原始数据的第二次扫描, 改变数据集中的数据
        for sublist in dataset:
            # 仅保留在head_table中存在的字母
            sublist[:] = [item for item in sublist if item in [node.head_name for node in self.head_table]]
        
            # 按照 head_table 中的 num 值进行降序排序
            sublist.sort(key=lambda item: next((node.num for node in self.head_table if node.head_name == item), 0), reverse=True)
            
        self.Create_FP_tree(dataset) # 生成一个FP树

        self.cond_base = self.Search_FP_tree() # 搜索FP树
        
    def predict(self):
        """输出条件模式基"""
        for key, inner_dict in self.cond_base.items():
            if inner_dict:  # 只输出非空字典
                # 对内部字典降序输出
                sorted_dict = dict(sorted(inner_dict.items(), key=lambda x : x[1], reverse=True))
                print(key , " : ", sorted_dict)
            
    
    def scan_data(self, dataset):
        """扫描一遍数据并删除低于阈值的那一部分"""
        for sublist in dataset:
            for item in sublist:
                if item in self.mapping:
                    self.mapping[item] += 1
                else:
                    self.mapping[item] = 1

        # 删除低于阈值的部分
        self.N = len(dataset)
        to_del = []
        for item, nums in self.mapping.items():
            if nums / self.N  <= self.ThresHold:
                to_del.append(item)
        
        for item in to_del:
            del self.mapping[item]
        
        
        
    def Create_FP_tree(self, dataset):
        """
        生成一个FP树
        """
        for sublist in dataset:
            pres_node = self.root_node  # 每个数据都从根节点重新生成
            for item in sublist:
                if item in pres_node.children:  # 当前的node在parent中存在了,只需要将val + 1即可
                    pres_node.children[item].val += 1
                    pres_node = pres_node.children[item]  # 往下走一步
                else:  # 如果不存在这个孩子节点
                    tmp_node = FP_node(item, 1, pres_node, None)
                    pres_node.children[item] = tmp_node
                    
                    # z好到这个字母对应的软链接在哪
                    soft_node = self.head_table_mapping[item]
                    while soft_node.soft_link != None:
                        soft_node = soft_node.soft_link
                    soft_node.soft_link = tmp_node  # 上一个soft_node 接上这个tmp_node
                    
                    pres_node = tmp_node # 当前节点就到了 tmp_node 这儿
        
    def Search_FP_tree(self):
        """
        对FP树进行反向搜索, 生成条件模式基
        """
        result = {}
        for idx in range(len(self.head_table) - 1, -1, -1):
            HeadNode = self.head_table[idx] # 获取当前的项头表中的点
            pres_cond_base = {}  # 存储当前节点的条件模式基
            soft_node = HeadNode.soft_link
            while soft_node != None:
                parent = soft_node.parent
                while parent.node_name != "root":
                    if parent.node_name not in pres_cond_base:
                        pres_cond_base[parent.node_name] = soft_node.val
                    else:
                        pres_cond_base[parent.node_name] += soft_node.val
                    
                    parent = parent.parent # 向上走到当前节点的双亲节点 直到走到root_node
                soft_node = soft_node.soft_link
            # 到这儿pres_cond_base已经生成了,但是要将低于Threshold的值给去掉
            to_del = []
            for item, nums in pres_cond_base.items():
                if nums / self.N <= self.ThresHold:
                    to_del.append(item)
            for item in to_del:
                del pres_cond_base[item]
            
            result[HeadNode.head_name] = pres_cond_base
        
        return result            

其中fit函数是主要的入口,该方法接收数据作为输入,执行了FPGrowth算法的主要流程。

FPGrowth算法的主要流程为:

  • 扫描第一遍数据,将低于支持度阈值的数据剔除。
  • 扫描第二遍数据,对原始数据进行排序,为构建FPGrowth树做准备。
  • 构建FP树和项头表。
  • 反向挖掘FP树,得到每个item的频繁项集

在kaggle提供的数据集中,对算法进行了测试https://www.kaggle.com/datasets/parisanahmadi/market-basket-optimisation?resource=download

测试的代码如下:

import pandas as pd
import numpy as np

data = pd.read_csv("Market_Basket_Optimisation.csv", header=None)

data_list = data.values.tolist()
data_without_nan = []
for sublist in data_list:
    new_sub = []
    for item in sublist:
        if type(item) == str:
            new_sub.append(item)
        else:
            break
    data_without_nan.append(new_sub)

from FPGrowth import *  # 将上述的所有代码都写在FPGrowth.py文件中

ThresHold = 0.05
model = FPGrowth(ThresHold)
model.fit(data_without_nan)
model.predict()

输出的结果如下:

chocolate  :  {'mineral water': 396}
spaghetti  :  {'mineral water': 448}
eggs  :  {'mineral water': 382}

可以通过更改ThresHold的值来得到不同的输出结果。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值