import numpy as np
import pandas as pd
import math
class DRSA:
def __init__(self, data, target_col, alpha=0.5):
"""
DRSA类的初始化方法,其中data为输入数据集,target_col为目标列,alpha为可分辨性阈值,默认为0.5。
"""
self.data = data
self.target_col = target_col
self.alpha = alpha
self.disc_data = self.discretize(data) # 离散化后的数据集
self.disc_mat = self.calc_discernibility_matrix(self.disc_data) # 可分辨矩阵
self.rules = self.generate_rules() # 生成的规则集
def discretize(self, data):
"""
数据离散化方法,将连续型数据转换为离散型数据。
"""
disc_data = data.copy()
for col in disc_data.columns:
if col != self.target_col:
disc_data[col] = pd.cut(data[col], 5, labels=False) # 将连续型数据离散化
return disc_data
def calc_discernibility_matrix(self, data):
"""
计算可分辨矩阵方法,根据数据集计算可分辨矩阵。
"""
disc_mat = np.zeros((len(data), len(data)))
for i in range(len(data)):
for j in range(len(data)):
disc_mat[i][j] = self.calc_discernibility(data.iloc[i], data.iloc[j]) # 计算可分辨性
return disc_mat
def calc_discernibility(self, x, y):
"""
计算可分辨性方法,根据数据x和y计算可分辨性。
"""
disc = 0
for col in self.disc_data.columns:
if col != self.target_col:
if x[col] != y[col]:
disc += 1
return disc
def calc_rule_quality(self, rule):
"""
计算规则质量方法,根据规则计算其质量。
"""
supp = len(self.data.query(rule)) # 支持度
conf = supp / len(self.data) # 置信度
# 计算可分辨性
discernibility = np.sum(self.disc_mat[np.ix_(self.data.index.isin(self.data.query(rule).index),
self.data.index.isin(self.data.query(rule).index))])
if discernibility == 0:
return 0
else:
return self.alpha * conf + (1 - self.alpha) * (1 - discernibility / len(self.data))
def generate_rules(self):
"""
生成规则方法,根据数据集生成规则。
"""
rules = []
for col in self.disc_data.columns:
if col != self.target_col:
for val in self.disc_data[col].unique():
rules.append(f"{col} == {val}")
rules.append(f"{self.target_col} == {self.data[self.target_col].mode()[0]}")
return rules
def fit(self):
"""
拟合方法,根据数据集拟合DRSA模型,生成规则。
"""
rules_quality = {}
for rule in self.rules:
rules_quality[rule] = self.calc_rule_quality(rule) # 计算规则质量
self.rules = [rule for rule in rules_quality if rules_quality[rule] > 0] # 剔除质量为0的规则
self.rules.sort(key=lambda x: rules_quality[x], reverse=True) # 按照规则质量从高到低排序
def predict(self, data):
"""
预测方法,根据输入数据预测其类别。
"""
disc_data = self.discretize(data) # 将输入数据离散化
for rule in self.rules:
if eval(rule): # 判断输入数据是否满足规则
return self.data[self.target_col].mode()[0] # 返回目标列众数作为预测结果
return self.data[self.target_col].mode()[0]
【无标题】DRSA代码
于 2023-12-11 21:58:54 首次发布