import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import adjusted_rand_score, silhouette_score
import matplotlib.pyplot as plt
from io import StringIO
import warnings
# Silence matplotlib UserWarnings so they don't clutter the console output.
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
# Every title / axis label in this script is Chinese. Matplotlib's default
# font (DejaVu Sans) has no CJK glyphs, so without this the labels render as
# empty boxes. List CJK-capable fonts commonly shipped on Windows, macOS and
# Linux; matplotlib uses whichever of them is installed, and DejaVu Sans
# remains the last-resort fallback.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = [
    "SimHei",
    "Microsoft YaHei",
    "PingFang SC",
    "Noto Sans CJK SC",
    "WenQuanYi Micro Hei",
    "DejaVu Sans",
]
# Many CJK fonts lack the Unicode minus glyph; draw negative ticks with ASCII '-'.
plt.rcParams["axes.unicode_minus"] = False
# --------------------- 1. Data loading ---------------------
# Iris dataset embedded as CSV text: one header row followed by 150 samples,
# 50 each of setosa, versicolor and virginica (in that order).
# Columns: four measurements (presumably centimetres, as the plot labels
# further down state) and the Species name.
# The string starts with a newline, so the header sits on the first
# non-empty line (pandas' read_csv skips blank lines by default).
iris_data = """
Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5.0,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
4.6,3.4,1.4,0.3,setosa
5.0,3.4,1.5,0.2,setosa
4.4,2.9,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.4,3.7,1.5,0.2,setosa
4.8,3.4,1.6,0.2,setosa
4.8,3.0,1.4,0.1,setosa
4.3,3.0,1.1,0.1,setosa
5.8,4.0,1.2,0.2,setosa
5.7,4.4,1.5,0.4,setosa
5.4,3.9,1.3,0.4,setosa
5.1,3.5,1.4,0.3,setosa
5.7,3.8,1.7,0.3,setosa
5.1,3.8,1.5,0.3,setosa
5.4,3.4,1.7,0.2,setosa
5.1,3.7,1.5,0.4,setosa
4.6,3.6,1.0,0.2,setosa
5.1,3.3,1.7,0.5,setosa
4.8,3.4,1.9,0.2,setosa
5.0,3.0,1.6,0.2,setosa
5.0,3.4,1.6,0.4,setosa
5.2,3.5,1.5,0.2,setosa
5.2,3.4,1.4,0.2,setosa
4.7,3.2,1.6,0.2,setosa
4.8,3.1,1.6,0.2,setosa
5.4,3.4,1.5,0.4,setosa
5.2,4.1,1.5,0.1,setosa
5.5,4.2,1.4,0.2,setosa
4.9,3.1,1.5,0.2,setosa
5.0,3.2,1.2,0.2,setosa
5.5,3.5,1.3,0.2,setosa
4.9,3.6,1.4,0.1,setosa
4.4,3.0,1.3,0.2,setosa
5.1,3.4,1.5,0.2,setosa
5.0,3.5,1.3,0.3,setosa
4.5,2.3,1.3,0.3,setosa
4.4,3.2,1.3,0.2,setosa
5.0,3.5,1.6,0.6,setosa
5.1,3.8,1.9,0.4,setosa
4.8,3.0,1.4,0.3,setosa
5.1,3.8,1.6,0.2,setosa
4.6,3.2,1.4,0.2,setosa
5.3,3.7,1.5,0.2,setosa
5.0,3.3,1.4,0.2,setosa
7.0,3.2,4.7,1.4,versicolor
6.4,3.2,4.5,1.5,versicolor
6.9,3.1,4.9,1.5,versicolor
5.5,2.3,4.0,1.3,versicolor
6.5,2.8,4.6,1.5,versicolor
5.7,2.8,4.5,1.3,versicolor
6.3,3.3,4.7,1.6,versicolor
4.9,2.4,3.3,1.0,versicolor
6.6,2.9,4.6,1.3,versicolor
5.2,2.7,3.9,1.4,versicolor
5.0,2.0,3.5,1.0,versicolor
5.9,3.0,4.2,1.5,versicolor
6.0,2.2,4.0,1.0,versicolor
6.1,2.9,4.7,1.4,versicolor
5.6,2.9,3.6,1.3,versicolor
6.7,3.1,4.4,1.4,versicolor
5.6,3.0,4.5,1.5,versicolor
5.8,2.7,4.1,1.0,versicolor
6.2,2.2,4.5,1.5,versicolor
5.6,2.5,3.9,1.1,versicolor
5.9,3.2,4.8,1.8,versicolor
6.1,2.8,4.0,1.3,versicolor
6.3,2.5,4.9,1.5,versicolor
6.1,2.8,4.7,1.2,versicolor
6.4,2.9,4.3,1.3,versicolor
6.6,3.0,4.4,1.4,versicolor
6.8,2.8,4.8,1.4,versicolor
6.7,3.0,5.0,1.7,versicolor
6.0,2.9,4.5,1.5,versicolor
5.7,2.6,3.5,1.0,versicolor
5.5,2.4,3.8,1.1,versicolor
5.5,2.4,3.7,1.0,versicolor
5.8,2.7,3.9,1.2,versicolor
6.0,2.7,5.1,1.6,versicolor
5.4,3.0,4.5,1.5,versicolor
6.0,3.4,4.5,1.6,versicolor
6.7,3.1,4.7,1.5,versicolor
6.3,2.3,4.4,1.3,versicolor
5.6,3.0,4.1,1.3,versicolor
5.5,2.5,4.0,1.3,versicolor
5.5,2.6,4.4,1.2,versicolor
6.1,3.0,4.6,1.4,versicolor
5.8,2.6,4.0,1.2,versicolor
5.0,2.3,3.3,1.0,versicolor
5.6,2.7,4.2,1.3,versicolor
5.7,3.0,4.2,1.2,versicolor
5.7,2.9,4.2,1.3,versicolor
6.2,2.9,4.3,1.3,versicolor
5.1,2.5,3.0,1.1,versicolor
5.7,2.8,4.1,1.3,versicolor
6.3,3.3,6.0,2.5,virginica
5.8,2.7,5.1,1.9,virginica
7.1,3.0,5.9,2.1,virginica
6.3,2.9,5.6,1.8,virginica
6.5,3.0,5.8,2.2,virginica
7.6,3.0,6.6,2.1,virginica
4.9,2.5,4.5,1.7,virginica
7.3,2.9,6.3,1.8,virginica
6.7,2.5,5.8,1.8,virginica
7.2,3.6,6.1,2.5,virginica
6.5,3.2,5.1,2.0,virginica
6.4,2.7,5.3,1.9,virginica
6.8,3.0,5.5,2.1,virginica
5.7,2.5,5.0,2.0,virginica
5.8,2.8,5.1,2.4,virginica
6.4,3.2,5.3,2.3,virginica
6.5,3.0,5.5,1.8,virginica
7.7,3.8,6.7,2.2,virginica
7.7,2.6,6.9,2.3,virginica
6.0,2.2,5.0,1.5,virginica
6.9,3.2,5.7,2.3,virginica
5.6,2.8,4.9,2.0,virginica
7.7,2.8,6.7,2.0,virginica
6.3,2.7,4.9,1.8,virginica
6.7,3.3,5.7,2.1,virginica
7.2,3.2,6.0,1.8,virginica
6.2,2.8,4.8,1.8,virginica
6.1,3.0,4.9,1.8,virginica
6.4,2.8,5.6,2.1,virginica
7.2,3.0,5.8,1.6,virginica
7.4,2.8,6.1,1.9,virginica
7.9,3.8,6.4,2.0,virginica
6.4,2.8,5.6,2.2,virginica
6.3,2.8,5.1,1.5,virginica
6.1,2.6,5.6,1.4,virginica
7.7,3.0,6.1,2.3,virginica
6.3,3.4,5.6,2.4,virginica
6.4,3.1,5.5,1.8,virginica
6.0,3.0,4.8,1.8,virginica
6.9,3.1,5.4,2.1,virginica
6.7,3.1,5.6,2.4,virginica
6.9,3.1,5.1,2.3,virginica
5.8,2.7,5.1,1.9,virginica
6.8,3.2,5.9,2.3,virginica
6.7,3.3,5.7,2.5,virginica
6.7,3.0,5.2,2.3,virginica
6.3,2.5,5.0,1.9,virginica
6.5,3.0,5.2,2.0,virginica
6.2,3.4,5.4,2.3,virginica
5.9,3.0,5.1,1.8,virginica
"""
# Parse the embedded CSV text into a DataFrame.
df = pd.read_csv(StringIO(iris_data))
# Sanity check: print the dimensions, then a five-row preview.
print(f"数据集形状: {df.shape}", "前5行数据:", sep="\n")
print(df.head())
# --------------------- 2. Preprocessing ---------------------
# Feature matrix: the four numeric measurement columns as a NumPy array
# of shape (n_samples, 4), in sepal-length/width, petal-length/width order.
X = df[["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"]].values
# Encode species names as integer codes. LabelEncoder assigns codes in
# sorted order of the class names, so species_names[i] is the name for code i.
# y_true is only used as ground truth for evaluation, never for clustering.
le = LabelEncoder()
y_true = le.fit_transform(df["Species"])
species_names = le.classes_
print("\n物种编码映射:")
for i, name in enumerate(species_names):
    print(f"{name} → {i}")
# --------------------- 3. K-Means clustering ---------------------
# Three clusters (one per expected species). k-means++ seeding is run
# n_init=10 times and the run with the lowest inertia is kept; each run is
# capped at 300 iterations; random_state fixes the seeding for reproducibility.
# fit() returns the estimator itself, so construction and fitting chain.
kmeans = KMeans(
    n_clusters=3,
    init='k-means++',
    n_init=10,
    max_iter=300,
    random_state=42,
).fit(X)
# Fitted results: per-sample cluster ids, centroid coordinates (same column
# order as X), and the sum of squared distances of samples to their nearest
# centroid.
y_pred, centroids, inertia = kmeans.labels_, kmeans.cluster_centers_, kmeans.inertia_
print("\n聚类完成!")
print(f"样本到最近聚类中心的平方和: {inertia:.2f}")
# --------------------- 4. Evaluating the clustering ---------------------
# ARI compares the predicted clusters with the true species labels (external
# metric); the silhouette score judges cluster separation from X alone
# (internal metric).
ari, silhouette = adjusted_rand_score(y_true, y_pred), silhouette_score(X, y_pred)
print("\n聚类评估:")
print(f"调整兰德指数(ARI): {ari:.4f} (1.0表示完美匹配)")
print(f"轮廓系数: {silhouette:.4f} (越高表示聚类效果越好)")
# Print each centroid in the original feature units. The columns follow X:
# sepal length, sepal width, petal length, petal width.
print("\n聚类中心坐标:")
for i, center in enumerate(centroids):
    print(f"Cluster {i}:")
    print(f" 萼片长度: {center[0]:.2f} cm, 萼片宽度: {center[1]:.2f} cm")
    print(f" 花瓣长度: {center[2]:.2f} cm, 花瓣宽度: {center[3]:.2f} cm")
# K-Means cluster ids are arbitrary, so map each cluster to the species that
# occurs most often in it (its majority label) and report that count.
cluster_species = {}
print("\n聚类物种分布:")
# Iterate over the model's own cluster count instead of a hard-coded 3, so
# changing n_clusters above keeps this report consistent.
for cluster in range(kmeans.n_clusters):
    species_in_cluster = df.loc[y_pred == cluster, 'Species'].value_counts()
    dominant_species = species_in_cluster.idxmax()
    cluster_species[cluster] = dominant_species
    print(f"Cluster {cluster} 主要包含: {dominant_species} ({species_in_cluster[dominant_species]}个样本)")
# --------------------- 5. Visualization ---------------------
# A 2x2 figure. Panel 1: petal length vs petal width, with points coloured by
# predicted cluster and centroids drawn as large red crosses.
plt.figure(figsize=(15, 10))
ax1 = plt.subplot(2, 2, 1)
scatter1 = ax1.scatter(X[:, 2], X[:, 3], c=y_pred, cmap='viridis',
                       s=50, alpha=0.7, edgecolor='k')
centroids1 = ax1.scatter(centroids[:, 2], centroids[:, 3],
                         marker='X', s=200, c='red', label='聚类中心')
ax1.set(title='花瓣长度 vs 花瓣宽度', xlabel='花瓣长度 (cm)', ylabel='花瓣宽度 (cm)')
# Only this panel carries a legend entry for the centroid marker.
ax1.legend()
ax1.grid(alpha=0.3)
# Panel 2: sepal length vs sepal width, using the same colouring and centroid
# marker as panel 1 (no legend here).
ax2 = plt.subplot(2, 2, 2)
scatter2 = ax2.scatter(X[:, 0], X[:, 1], c=y_pred, cmap='viridis',
                       s=50, alpha=0.7, edgecolor='k')
centroids2 = ax2.scatter(centroids[:, 0], centroids[:, 1],
                         marker='X', s=200, c='red')
ax2.set(title='萼片长度 vs 萼片宽度', xlabel='萼片长度 (cm)', ylabel='萼片宽度 (cm)')
ax2.grid(alpha=0.3)
# Panel 3: a mixed pair, sepal length (x) vs petal length (y).
ax3 = plt.subplot(2, 2, 3)
scatter3 = ax3.scatter(X[:, 0], X[:, 2], c=y_pred, cmap='viridis',
                       s=50, alpha=0.7, edgecolor='k')
centroids3 = ax3.scatter(centroids[:, 0], centroids[:, 2],
                         marker='X', s=200, c='red')
ax3.set(title='萼片长度 vs 花瓣长度', xlabel='萼片长度 (cm)', ylabel='花瓣长度 (cm)')
ax3.grid(alpha=0.3)
# Panel 4: a stacked bar chart (one bar per cluster, one segment per species)
# showing how many samples of each species landed in each cluster.
ax4 = plt.subplot(2, 2, 4)
n_clusters = kmeans.n_clusters
# cluster_counts[c, s] = number of samples of species_names[s] in cluster c.
# Looking up species in species_names order keeps every row's column order
# identical and yields explicit zeros for species absent from a cluster.
cluster_counts = []
for cluster in range(n_clusters):
    species_in_cluster = df.loc[y_pred == cluster, 'Species'].value_counts()
    cluster_counts.append([species_in_cluster.get(s, 0) for s in species_names])
cluster_counts = np.array(cluster_counts)  # shape (n_clusters, n_species)
bar_width = 0.6
x = np.arange(n_clusters)
# Running top of each bar; every species segment is stacked on top of it.
bottom = np.zeros(n_clusters)
colors = ['#ff9999', '#66b3ff', '#99ff99']
for i, species in enumerate(species_names):
    ax4.bar(
        x,
        cluster_counts[:, i],
        bottom=bottom,
        width=bar_width,
        label=species,
        color=colors[i % len(colors)],  # cycle if there are more species than colours
        alpha=0.7,
    )
    bottom += cluster_counts[:, i]
ax4.set_title('每个聚类中的物种分布')
ax4.set_xlabel('聚类')
ax4.set_ylabel('样本数量')
ax4.set_xticks(x)
ax4.set_xticklabels([f'Cluster {c}' for c in range(n_clusters)])
ax4.legend(loc='upper right')
ax4.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('iris_clustering_results.png', dpi=300, bbox_inches='tight')
plt.show()
# --------------------- 6. Elbow method for choosing K ---------------------
# Fit K-Means for K = 1..9 and plot inertia against K. The "elbow", where
# the curve stops dropping steeply, suggests a reasonable K.
plt.figure(figsize=(10, 6))
k_range = range(1, 10)
# The throwaway per-K models are not bound to a global name, so the K=3
# model in `kmeans` fitted in section 3 is not overwritten by the K=9 one.
inertias = [
    KMeans(n_clusters=k, random_state=42, n_init=10).fit(X).inertia_
    for k in k_range
]
plt.plot(k_range, inertias, 'bo-')
plt.xlabel('聚类数量 (K)')
plt.ylabel('样本到最近中心的平方和')
plt.title('肘部法确定最佳K值')
plt.xticks(k_range)
plt.grid(True)
plt.savefig('elbow_method.png', dpi=300, bbox_inches='tight')
plt.show()
# --------------------- 7. Cluster boundary visualization ---------------------
# Fit a separate K-Means on only petal length and petal width, so that its
# decision regions can be drawn in 2-D, then colour a fine grid by the
# cluster each grid point is assigned to.
plt.figure(figsize=(12, 8))
X_subset = X[:, 2:4]  # columns: petal length, petal width
# Grid covering the data with a 0.5 cm margin, sampled every 0.02 cm.
x_min, x_max = X_subset[:, 0].min() - 0.5, X_subset[:, 0].max() + 0.5
y_min, y_max = X_subset[:, 1].min() - 0.5, X_subset[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
                     np.arange(y_min, y_max, 0.02))
kmeans_subset = KMeans(n_clusters=3, random_state=42, n_init=10)
kmeans_subset.fit(X_subset)
Z = kmeans_subset.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
# Regions, point colours and centroids must all come from the SAME 2-feature
# model. The 4-feature model's label ids (y_pred) and centroids are not
# guaranteed to line up with kmeans_subset's, which would put points inside
# a region of a different colour.
n_sub = kmeans_subset.n_clusters
# One filled band per cluster id (boundaries at k +/- 0.5), with a shared
# colour normalisation so region k and points labelled k get the same colour.
color_scale = dict(cmap='viridis', vmin=-0.5, vmax=n_sub - 0.5)
plt.contourf(xx, yy, Z, levels=np.arange(n_sub + 1) - 0.5, alpha=0.4, **color_scale)
plt.scatter(
    X_subset[:, 0], X_subset[:, 1],
    c=kmeans_subset.labels_,
    s=50,
    alpha=0.8,
    edgecolor='k',
    **color_scale,
)
plt.scatter(
    kmeans_subset.cluster_centers_[:, 0], kmeans_subset.cluster_centers_[:, 1],
    marker='X', s=200, c='red', label='聚类中心'
)
plt.title('花瓣特征的聚类边界')
# Axis labels (horizontal: petal length, vertical: petal width).
plt.xlabel('花瓣长度 (cm)')
plt.ylabel('花瓣宽度 (cm)')
plt.legend()
plt.grid(alpha=0.3)
plt.savefig('cluster_boundaries.png', dpi=300, bbox_inches='tight')
plt.show()