王蓉 - 完美

import pandas as pdimport numpy as npfrom sklearn.model_selection import train_test_splitfrom sklearn.naive_bayes import GaussianNBfrom sklearn.metrics import accuracy_score, classification_report, confusion_matriximport matplotlib.pyplot as pltimport seaborn as snsfrom sklearn.svm import SVCfrom sklearn.tree import DecisionTreeClassifierfrom sklearn.ensemble import RandomForestClassifier# 设置中文字体plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']plt.rcParams['axes.unicode_minus'] = False# 类别名称class_names = ["T恤", "裤子", "套头衫", "连衣裙", "外套", "凉鞋", "衬衫", "运动鞋", "包", "短靴"]# -------------------------- 1. 数据准备(读取本地CSV文件) --------------------------# 本地CSV文件路径(Windows路径用反斜杠\,或加r转义)train_csv_path = r"C:\Users\王蓉\Desktop\机器学习\fashion-mnist_train.csv"test_csv_path = r"C:\Users\王蓉\Desktop\机器学习\fashion-mnist_test.csv"# 读取CSV文件(Fashion-MNIST的CSV格式:第一列为标签label,后续784列为像素值)print("正在读取本地CSV文件...")train_df = pd.read_csv(train_csv_path)test_df = pd.read_csv(test_csv_path)# 分离特征(X)和标签(y)# 训练集:label列是标签,其余列是像素特征X_train = train_df.drop("label", axis=1).values # (60000, 784)y_train = train_df["label"].values # (60000,)# 测试集:同上X_test = test_df.drop("label", axis=1).values # (10000, 784)y_test = test_df["label"].values # (10000,)# 归一化:将像素值从[0,255]缩放到[0,1](CSV中像素已为0-255整数)X_train = X_train / 255.0X_test = X_test / 255.0# 划分训练集与验证集(8:2拆分,保持类别分布一致)X_train, X_val, y_train, y_val = train_test_split( X_train, y_train, test_size=0.2, random_state=42, stratify=y_train)print(f"数据维度概况:")print(f"训练集:X_train={X_train.shape}, y_train={y_train.shape}")print(f"验证集:X_val={X_val.shape}, y_val={y_val.shape}")print(f"测试集:X_test={X_test.shape}, y_test={y_test.shape}")# -------------------------- 2. 构建高斯朴素贝叶斯模型 --------------------------# 高斯朴素贝叶斯关键参数说明:# - priors: 类别先验概率(默认None,使用训练集类别分布作为先验)# - var_smoothing: 方差平滑项(默认1e-9,避免特征方差为0导致概率计算出错)clf = GaussianNB( priors=None, # 采用训练集类别分布作为先验(符合朴素贝叶斯默认逻辑) var_smoothing=1e-9 # 默认平滑项,确保数值稳定性)# -------------------------- 3. 模型训练与评估 --------------------------# 训练模型print("\n开始训练高斯朴素贝叶斯模型...")clf.fit(X_train, y_train)# 1. 训练集性能评估(观察模型拟合程度)y_pred_train = clf.predict(X_train)accuracy_train = accuracy_score(y_train, y_pred_train)print(f"\n===== 训练集性能 =====")print(f"训练集准确率: {accuracy_train:.4f}")# 2. 验证集性能评估(初步验证模型泛化能力)y_pred_val = clf.predict(X_val)accuracy_val = accuracy_score(y_val, y_pred_val)print(f"\n===== 验证集性能 =====")print(f"验证集准确率: {accuracy_val:.4f}")print("验证集分类报告:")print(classification_report(y_val, y_pred_val, target_names=[ "T恤", "裤子", "套头衫", "连衣裙", "外套", "凉鞋", "衬衫", "运动鞋", "包", "短靴"])) # 补充类别名称,使报告更易读print("验证集混淆矩阵:")print(confusion_matrix(y_val, y_pred_val))# 3. 测试集性能评估(最终模型性能)y_pred_test = clf.predict(X_test)accuracy_test = accuracy_score(y_test, y_pred_test)print(f"\n===== 测试集性能 =====")print(f"测试集准确率: {accuracy_test:.4f}")print("测试集分类报告:")print(classification_report(y_test, y_pred_test, target_names=[ "T恤", "裤子", "套头衫", "连衣裙", "外套", "凉鞋", "衬衫", "运动鞋", "包", "短靴"]))print("测试集混淆矩阵:")print(confusion_matrix(y_test, y_pred_test))#性能分析print("\n" + "="*50)print("性能分析")print("="*50)models = { '高斯朴素贝叶斯': clf, '支持向量机(SVM)': SVC(kernel='rbf', gamma='scale', random_state=42), '决策树': DecisionTreeClassifier(max_depth=10, random_state=42), '随机森林': RandomForestClassifier(n_estimators=100, random_state=42)}results = {}for name, model in models.items(): print(f"训练{name}...") if name != '高斯朴素贝叶斯': # 高斯朴素贝叶斯已经训练过 model.fit(X_train, y_train) # 所有模型都需要进行预测 y_pred = model.predict(X_test) accuracy = accuracy_score(y_test, y_pred) results[name] = accuracy print(f" {name}测试准确率: {accuracy:.4f}")# 可视化对比plt.figure(figsize=(10, 6))models_list = list(results.keys())acc_list = [results[m] for m in models_list]bars = plt.bar(models_list, acc_list, color=['blue', 'green', 'orange', 'red'])plt.ylabel('测试准确率')plt.title('不同分类模型在Fashion-MNIST上的性能对比')plt.ylim(0, 1.0)# 添加数值标签for bar, acc in zip(bars, acc_list): height = bar.get_height() plt.text(bar.get_x() + bar.get_width()/2., height + 0.01, f'{acc:.4f}', ha='center', va='bottom')plt.tight_layout()plt.show()# 4.2 类别区分能力分析print("\n" + "="*50)print("4.2 类别区分能力分析")print("="*50)# 获取测试集混淆矩阵(使用之前训练好的clf的预测结果)cm_test = confusion_matrix(y_test, y_pred_test)# 可视化混淆矩阵plt.figure(figsize=(12, 10))sns.heatmap(cm_test, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)plt.title('测试集混淆矩阵热图')plt.ylabel('真实标签')plt.xlabel('预测标签')plt.tight_layout()plt.show()# 分析每个类别的精确率和召回率print("\n各类别性能分析:")for i, class_name in enumerate(class_names): tp = cm_test[i, i] # 真正例 fp = cm_test[:, i].sum() - tp # 假正例 fn = cm_test[i, :].sum() - tp # 假反例 precision = tp / (tp + fp) if (tp + fp) > 0 else 0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0 print(f"{class_name}: 精确率={precision:.3f}, 召回率={recall:.3f}")# 找出最容易混淆的类别对print("\n最容易混淆的类别对(前3对):")confusion_pairs = []for i in range(10): for j in range(10): if i != j and cm_test[i, j] > 0: confusion_pairs.append((i, j, cm_test[i, j]))# 按混淆数量排序confusion_pairs.sort(key=lambda x: x[2], reverse=True)for idx, (i, j, count) in enumerate(confusion_pairs[:3]): print(f"{class_names[i]} → {class_names[j]}: {count}个样本") if idx == 0: worst_pair = (class_names[i], class_names[j], count)print(f"\n最易混淆的类别: {worst_pair[0]} 和 {worst_pair[1]}")print("可能原因分析:")print("- 形状相似性: 两种服装在轮廓上可能相似")print("- 纹理相似性: 材质或图案可能相近")print("- 像素分布重叠: 在某些关键特征区域像素值分布相似")# 4.3 特征独立性假设分析print("\n" + "="*50)print("4.3 特征独立性假设分析")print("="*50)print("朴素贝叶斯的核心假设:特征条件独立性")print("\n对于图像数据,像素之间通常不满足独立性假设:")print("1. 空间相关性:相邻像素通常高度相关")print("2. 局部结构:像素形成边缘、纹理等局部模式")print("3. 全局模式:像素组成完整的物体形状")# 计算像素相关性(示例)print("\n相关性分析示例(随机选取50个像素):")np.random.seed(42)sample_pixels = np.random.choice(784, 50, replace=False)X_sample = X_train[:, sample_pixels]corr_matrix = np.corrcoef(X_sample, rowvar=False)# 计算平均绝对相关性(忽略对角线)np.fill_diagonal(corr_matrix, 0)mean_abs_corr = np.mean(np.abs(corr_matrix))print(f"平均绝对相关系数: {mean_abs_corr:.4f}")print(f"相关系数 > 0.3的比例: {np.sum(np.abs(corr_matrix) > 0.3) / (corr_matrix.size - 50):.4f}")print("\n独立性假设对模型性能的影响:")print("1. 正面影响:")print(" - 简化计算:从估计784维联合分布变为784个一维分布")print(" - 减少过拟合:参数数量从O(n²)降到O(n)")print(" - 训练快速:计算复杂度大幅降低")print("2. 负面影响:")print(" - 忽略空间结构:无法捕捉图像中的局部模式")print(" - 信息损失:丢弃了像素间的关联信息")print(" - 性能受限:准确率通常低于考虑空间相关性的模型")print("\n" + "="*50)print("实验总结")print("="*50)print(f"1. 高斯朴素贝叶斯在Fashion-MNIST上达到 {accuracy_test:.4f} 准确率")# 计算性能排名sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)rank = [name for name, _ in sorted_results].index('高斯朴素贝叶斯') + 1print(f"2. 与其他模型相比,性能排名: 第{rank}名")print(f"3. 最易混淆的类别: {worst_pair[0]} 和 {worst_pair[1]}")print(f"4. 特征独立性假设是性能的主要限制因素")修改完善代码
最新发布
12-03
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值