import math
def create_data_set():
# 创建一个示例数据集和标签 返回数据集和标签
data_set = [
[1, 1, 'yes'], # 数据
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']
]
labels = ['no surfacing', 'flippers'] # 特征标签
return data_set, labels # 返回数据集和标签
def calc_shannon_ent(data_set):
# 计算给定数据集的信息熵 输入 data_set(数据集) 输出 shannon_ent(信息熵)
num_entries = len(data_set) # 数据集中的样本数
label_counts = {} # 用于存储各标签的计数
for feat_vec in data_set:
current_label = feat_vec[-1] # 获取当前样本的标签
if current_label not in label_counts.keys():
label_counts[current_label] = 0
label_counts[current_label] += 1 # 统计每个标签的出现次数
shannon_ent = 0.0 # 初始化信息熵
for key in label_counts:
prob = float(label_counts[key]) / num_entries # 计算每个标签的概率
shannon_ent -= prob * math.log(prob, 2) # 计算熵
return shannon_ent # 返回信息熵
def split_data_set(data_set, axis, value):
# 按照给定特征划分数据集 输入 data_set(数据集)axis(划分数据集的特征)value(特征的取值)
# 输出 划分后的数据
ret_data_set = [] # 创建新的列表用于存储划分后数据集
for feat_vec in data_set:
if feat_vec[axis] == value:
reduced_feat_vec = feat_vec[:axis] # 去掉 axis特征
reduced_feat_vec.extend(feat_vec[axis + 1:]) # 去掉后的特征组合
ret_data_set.append(reduced_feat_vec) # 添加到返回列表中
return ret_data_set # 返回划分后的数据
def choose_best_feature_to_split(data_set):
# 选择最好的特征划分数据集 输入 data_set(数据集) 输出 最佳特征的索引
num_features = len(data_set[0]) - 1 # 特征数量
base_entropy = calc_shannon_ent(data_set) # 计算数据集的基本熵
best_info_gain = 0.0 # 初始化信息增益
best_feature = -1 # 最佳特征的索引
for i in range(num_features):
feat_list = [example[i] for example in data_set] # 获取数据集中所有第 i个特征的值
unique_vals = set(feat_list) # 获取特征的唯一取值
new_entropy = 0.0 # 初始化新的熵
for value in unique_vals:
sub_data_set = split_data_set(data_set, i, value) # 划分数据集
prob = len(sub_data_set) / float(len(data_set)) # 计算子数据集的概率
new_entropy += prob * calc_shannon_ent(sub_data_set) # 计算熵
info_gain = base_entropy - new_entropy # 计算信息增益
if info_gain > best_info_gain: # 如果当前信息增益大于最佳信息增益
best_info_gain = info_gain # 更新最佳信息增益
best_feature = i # 更新最佳特征
return best_feature # 返回最佳特征的索引# 创建数据集和标签
data_set, labels = create_data_set()
# 选择最佳特征
best_feature_index = choose_best_feature_to_split(data_set)
# 输出结果
print(f"最佳特征的索引是: {best_feature_index}")
print(f"对应的特征标签是: {labels[best_feature_index]}")
2521

被折叠的 条评论
为什么被折叠?



