# Load required packages (install on first use).
# requireNamespace() is the correct availability check; require() is a
# loading function that merely returns FALSE on failure.
if (!requireNamespace("arules", quietly = TRUE)) install.packages("arules")
library(arules)
# Read the data (replace the file path with the actual location).
# header = FALSE: the CSV holds raw attribute values only; names come below.
data <- read.csv("E:/副本expanded.csv", header = FALSE, stringsAsFactors = FALSE)
# Feature names of the UCI Mushroom data set, in the standard column order.
feature_names <- c("class", "cap-shape", "cap-surface", "cap-color", "bruises", "odor",
                   "gill-attachment", "gill-spacing", "gill-size", "gill-color", "stalk-shape",
                   "stalk-root", "stalk-surface-above-ring", "stalk-surface-below-ring",
                   "stalk-color-above-ring", "stalk-color-below-ring", "veil-type", "veil-color",
                   "ring-number", "ring-type", "spore-print-color", "population", "habitat")
# Preprocess one raw data row into "feature=value" items.
# row:      character vector with one record's attribute values
# features: feature names matching row positions (defaults to the global
#           `feature_names`); added as a backward-compatible parameter
# Returns a character vector of "feature=value" items.
preprocess_transaction <- function(row, features = feature_names) {
  keep <- !is.na(row) & row != "" # drop empty cells but remember positions
  # Index names positionally so an empty mid-row cell cannot shift later
  # values onto the wrong feature name (the previous
  # feature_names[1:length(non_empty)] approach misaligned them).
  paste(features[seq_along(row)][keep], row[keep], sep = "=")
}
# Build the transaction list: one character vector of "feature=value"
# items per data row.
# NOTE: the previous apply()/strsplit() pipeline collapsed the per-row
# vectors into a matrix and then split single items on spaces, corrupting
# the transactions; lapply over row indices keeps one vector per row.
transactions_list <- lapply(seq_len(nrow(data)), function(i) {
  preprocess_transaction(unlist(data[i, ], use.names = FALSE))
})
# Kept under the original name for downstream compatibility; the list
# already holds one character vector per transaction, so no splitting.
transactions_list_split <- transactions_list
# Coerce to an arules transactions object
transactions_obj <- as(transactions_list_split, "transactions")
# Save the preprocessed data (optional)
write(transactions_obj, file = "preprocessed_mushroom.csv", format = "basket", sep = ",")
# 1. Support of an itemset: the fraction of transactions containing
# every item of `itemset`.
# itemset:      character vector of items
# transactions: list of character vectors (one per transaction)
# Returns a number in [0, 1].
calculate_support <- function(itemset, transactions) {
  contains_all <- vapply(transactions,
                         function(trans) all(itemset %in% trans),
                         logical(1))
  sum(contains_all) / length(transactions)
}
# 2. Candidate 1-itemsets: every distinct item seen anywhere in the
# transactions, each as a length-1 character vector in a list.
generate_c1 <- function(transactions) {
  distinct_items <- unique(unlist(transactions))
  as.list(distinct_items)
}
# 3. Prune to frequent 1-itemsets: keep only candidates whose support
# reaches min_sup.
# Returns a data.frame with a list-column `itemset` and numeric `support`.
prune_c1 <- function(c1, transactions, min_sup) {
  supports <- sapply(c1, calculate_support, transactions)
  candidates <- data.frame(
    itemset = I(c1),
    support = supports
  )
  candidates[candidates$support >= min_sup, ]
}
# 4. Join step: combine pairs of frequent (k-1)-itemsets that share their
# first k-2 items into candidate k-itemsets.
# frequent_prev: data.frame with a list-column `itemset` of (k-1)-itemsets
# Returns a list of unique candidate k-itemsets (each sorted).
join_itemsets <- function(frequent_prev) {
  n <- nrow(frequent_prev)
  if (n < 2) return(list()) # fewer than two itemsets: nothing to join
  itemsets <- frequent_prev$itemset
  k <- length(itemsets[[1]]) + 1
  prefix_len <- k - 2
  result <- list()
  for (i in seq_len(n - 1)) {
    for (j in (i + 1):n) {
      item_i <- itemsets[[i]]
      item_j <- itemsets[[j]]
      # Join when the first k-2 items agree. seq_len() (not 1:(k-2)) is
      # essential: for k = 2 it yields an empty index so all pairs of
      # 1-itemsets join, whereas 1:0 evaluates to c(1, 0), compared the
      # (distinct) first elements, and so no 2-itemsets were ever built.
      if (identical(item_i[seq_len(prefix_len)], item_j[seq_len(prefix_len)])) {
        new_itemset <- sort(unique(c(item_i, item_j)))
        if (length(new_itemset) == k) {
          result <- c(result, list(new_itemset))
        }
      }
    }
  }
  unique(result)
}
# 5. Prune step for candidate k-itemsets: a candidate survives only when
# every one of its (k-1)-item subsets is frequent (Apriori property);
# survivors are then filtered by support >= min_sup.
# Returns a data.frame with list-column `itemset` and numeric `support`.
prune_ck <- function(ck, frequent_prev, min_sup, transactions) {
  # TRUE when `sub` appears among the frequent (k-1)-itemsets
  is_frequent <- function(sub) {
    any(sapply(frequent_prev$itemset, identical, sub))
  }
  # TRUE when every (k-1)-subset of the candidate is frequent
  all_subsets_frequent <- function(itemset) {
    subsets <- combn(itemset, length(itemset) - 1, simplify = FALSE)
    all(sapply(subsets, is_frequent))
  }
  valid_ck <- Filter(all_subsets_frequent, ck)
  if (length(valid_ck) == 0) {
    return(data.frame(itemset = I(list()), support = numeric()))
  }
  supports <- sapply(valid_ck, calculate_support, transactions)
  kept <- data.frame(itemset = I(valid_ck), support = supports)
  kept[kept$support >= min_sup, ]
}
# 6. Full Apriori frequent-itemset mining.
# transactions: list of character vectors; min_sup: minimum support.
# Returns a data.frame of all frequent itemsets across every size k
# (list-column `itemset`, numeric `support`).
apriori_frequent <- function(transactions, min_sup) {
  # Level 1: frequent single items
  all_frequent <- prune_c1(generate_c1(transactions), transactions, min_sup)
  k <- 2
  repeat {
    # Frequent itemsets of the previous level feed the join step
    frequent_prev <- all_frequent[lengths(all_frequent$itemset) == (k - 1), ]
    if (nrow(frequent_prev) == 0) break
    ck <- join_itemsets(frequent_prev)
    if (length(ck) == 0) break
    Lk <- prune_ck(ck, frequent_prev, min_sup, transactions)
    if (nrow(Lk) == 0) break
    # Accumulate this level's frequent itemsets and move on
    all_frequent <- rbind(all_frequent, Lk)
    k <- k + 1
  }
  all_frequent
}
# 7. Generate association rules from frequent itemsets.
# For every frequent itemset of size >= 2, each non-empty proper subset is
# tried as the antecedent; rules whose confidence reaches min_conf are kept.
# Returns a data.frame with list-columns lhs/rhs plus support/confidence,
# deduplicated and sorted by descending confidence.
generate_rules <- function(frequent_itemsets, transactions, min_conf) {
  rules <- data.frame(
    lhs = I(list()),        # rule antecedent
    rhs = I(list()),        # rule consequent
    support = numeric(),
    confidence = numeric()
  )
  # Only itemsets of length >= 2 can be split into lhs => rhs
  frequent_large <- frequent_itemsets[lengths(frequent_itemsets$itemset) >= 2, ]
  # seq_len() keeps the loop empty when no large itemsets exist;
  # 1:nrow(...) would run with i = 1 and fail on the empty data.frame.
  for (i in seq_len(nrow(frequent_large))) {
    itemset <- frequent_large$itemset[[i]]
    itemset_support <- frequent_large$support[i]
    # Every non-empty proper subset is a candidate antecedent
    for (lhs_len in seq_len(length(itemset) - 1)) {
      lhs_list <- combn(itemset, lhs_len, simplify = FALSE)
      for (lhs in lhs_list) {
        rhs <- setdiff(itemset, lhs)
        # Antecedent support is the confidence denominator
        lhs_support <- calculate_support(lhs, transactions)
        if (lhs_support == 0) next # avoid division by zero
        confidence <- itemset_support / lhs_support
        # Keep only strong rules
        if (confidence >= min_conf) {
          rules <- rbind(rules, data.frame(
            lhs = I(list(lhs)),
            rhs = I(list(rhs)),
            support = itemset_support,
            confidence = confidence
          ))
        }
      }
    }
  }
  # Deduplicate and sort by confidence, highest first
  rules <- unique(rules)
  rules[order(-rules$confidence), ]
}
# 1. Support-threshold sensitivity analysis: for each candidate minimum
# support, record how many frequent itemsets are mined and how long the
# mining takes. Results are printed, written to CSV, and returned as a
# data.frame (min_sup, frequent_count, runtime).
support_analysis <- function(transactions, min_conf, support_values) {
  result <- data.frame(
    min_sup = numeric(),
    frequent_count = numeric(),
    runtime = numeric()
  )
  for (min_sup in support_values) {
    started <- Sys.time()
    frequent <- apriori_frequent(transactions, min_sup)
    elapsed <- as.numeric(difftime(Sys.time(), started, units = "secs"))
    result <- rbind(
      result,
      data.frame(min_sup = min_sup,
                 frequent_count = nrow(frequent),
                 runtime = elapsed)
    )
    cat(sprintf("min_sup=%.2f: 频繁项集数=%d, 运行时间=%.2f秒\n",
                min_sup, nrow(frequent), elapsed))
  }
  # Persist the summary table
  write.csv(result, "support_analysis.csv", row.names = FALSE)
  return(result)
}
# 2. Confidence-threshold sensitivity analysis: mine frequent itemsets
# once at a fixed minimum support, then count the strong rules produced at
# each candidate minimum confidence. Results are printed, written to CSV,
# and returned as a data.frame (min_conf, rule_count).
confidence_analysis <- function(transactions, min_sup, confidence_values) {
  result <- data.frame(
    min_conf = numeric(),
    rule_count = numeric()
  )
  # Frequent itemsets depend only on min_sup, so mine them a single time
  frequent <- apriori_frequent(transactions, min_sup)
  for (min_conf in confidence_values) {
    rules <- generate_rules(frequent, transactions, min_conf)
    result <- rbind(
      result,
      data.frame(min_conf = min_conf, rule_count = nrow(rules))
    )
    cat(sprintf("min_conf=%.2f: 强规则数=%d\n", min_conf, nrow(rules)))
  }
  # Persist the summary table
  write.csv(result, "confidence_analysis.csv", row.names = FALSE)
  return(result)
}
# Run the analyses (example parameters; adjust as needed).
# NOTE(review): `transactions_list` must be a list of character vectors
# (one per transaction) for the support computations to work — verify the
# preprocessing step above produces that shape.
support_values <- c(0.1, 0.2, 0.3, 0.4, 0.5) # range of minimum-support thresholds
confidence_values <- c(0.7, 0.8, 0.9, 0.95) # range of minimum-confidence thresholds
min_conf_fixed <- 0.8 # fixed confidence for the support analysis
min_sup_fixed <- 0.2 # fixed support for the confidence analysis
# Execute both sensitivity analyses
support_result <- support_analysis(transactions_list, min_conf_fixed, support_values)
confidence_result <- confidence_analysis(transactions_list, min_sup_fixed, confidence_values)
# Optional visualization of the analysis results.
if (!requireNamespace("ggplot2", quietly = TRUE)) install.packages("ggplot2")
library(ggplot2)
# ggplot objects are only auto-printed at the interactive top level; wrap
# them in print() so the plots also render when this script is source()d
# or run via Rscript.
# Minimum support vs. number of frequent itemsets
print(
  ggplot(support_result, aes(x = min_sup, y = frequent_count)) +
    geom_line(color = "blue", linewidth = 1.2) +
    labs(x = "最小支持度", y = "频繁项集数量", title = "支持度对频繁项集数量的影响") +
    theme_minimal()
)
# Minimum support vs. runtime
print(
  ggplot(support_result, aes(x = min_sup, y = runtime)) +
    geom_line(color = "red", linewidth = 1.2) +
    labs(x = "最小支持度", y = "运行时间(秒)", title = "支持度对运行时间的影响") +
    theme_minimal()
)
# Minimum confidence vs. number of strong rules
print(
  ggplot(confidence_result, aes(x = min_conf, y = rule_count)) +
    geom_line(color = "green", linewidth = 1.2) +
    labs(x = "最小置信度", y = "强规则数量", title = "置信度对强规则数量的影响") +
    theme_minimal()
)
# Mine the final strong rules (example thresholds: min_sup=0.2, min_conf=0.8)
final_frequent <- apriori_frequent(transactions_list, min_sup = 0.2)
final_rules <- generate_rules(final_frequent, transactions_list, min_conf = 0.8)
# Keep only rules related to edibility (mentioning class=EDIBLE / class=POISONOUS)
mentions_class <- sapply(final_rules$lhs, function(x) any(grepl("class=", x))) |
  sapply(final_rules$rhs, function(x) any(grepl("class=", x)))
poison_edible_rules <- final_rules[mentions_class, ]
# Sort by confidence then support (both descending) and take the top 5.
# head(..., 5) instead of [1:5, ]: with fewer than 5 matching rules the
# old indexing padded the result with all-NA rows.
top5_rules <- head(
  poison_edible_rules[order(-poison_edible_rules$confidence,
                            -poison_edible_rules$support), ],
  5
)
# Pretty-print mined rules: one "lhs => rhs" line per rule followed by its
# support and confidence.
# top5_rules: data.frame with list-columns lhs/rhs and numeric
#             support/confidence columns; may have zero rows.
print_top5_rules <- function(top5_rules) {
  # seq_len() keeps the loop empty for a zero-row data.frame, where
  # 1:nrow(...) would iterate over c(1, 0) and fail on index 0.
  for (i in seq_len(nrow(top5_rules))) {
    lhs_str <- paste(top5_rules$lhs[[i]], collapse = ", ")
    rhs_str <- paste(top5_rules$rhs[[i]], collapse = ", ")
    cat(sprintf("规则%d: %s => %s\n", i, lhs_str, rhs_str))
    cat(sprintf(" 支持度: %.3f, 置信度: %.3f\n\n",
                top5_rules$support[i], top5_rules$confidence[i]))
  }
}
# Print the top edibility-related rules
cat("Top5有意义的关联规则(与可食用性相关):\n\n")
print_top5_rules(top5_rules)
# Save the top rules. write.csv() cannot serialize list columns
# ("unimplemented type 'list'"), so flatten lhs/rhs into comma-separated
# strings before writing.
top5_export <- top5_rules
top5_export$lhs <- sapply(top5_rules$lhs, paste, collapse = ", ")
top5_export$rhs <- sapply(top5_rules$rhs, paste, collapse = ", ")
write.csv(top5_export, "top5_mushroom_rules.csv", row.names = FALSE)
# Request (translated from the original note): adjust this code so that it
# runs in R and produces results, without changing the data content.