K-Means 算法是一种常用的无监督机器学习算法,用于将一组数据点划分为预先指定数量的簇(聚类)。算法的核心思想是通过迭代优化的方式,将数据点划分到与其最近的聚类中心,然后更新聚类中心以最小化数据点与所属聚类中心的距离平方和。以下是 K-Means 算法的核心步骤:
-
初始化聚类中心:随机选择 K 个数据点作为初始聚类中心,其中 K 为预先指定的簇数。
-
分配数据点到最近的聚类中心:计算每个数据点到所有聚类中心的距离,然后将数据点分配到距离最近的聚类中心所属的簇。
-
更新聚类中心:对于每个簇,计算该簇内所有数据点的均值,将均值作为新的聚类中心。
-
重复步骤 2 和 3:重复执行步骤 2 和 3,直到聚类中心不再发生显著变化或达到预定的迭代次数。
-
输出聚类结果:算法收敛后,每个数据点都被分配到一个簇,形成了最终的聚类结果。
导入必要的库:pandas
用于数据处理,numpy
用于数值计算,matplotlib
用于绘图。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
定义 kmeans
函数:该函数实现了 K-Means 聚类算法。K-Means 是一种常用的聚类算法,它将数据分为预先指定的聚类簇数(在此代码中设定为 n_clusters = 7
)
-
读取数据文件:从名为
china_cities.csv
的文件中读取地图坐标数据,包括省份、城市名称和经纬度坐标。 -
执行聚类操作:使用上述定义的
kmeans
函数对坐标数据进行聚类操作,将地图上的城市分成了 7 个不同的聚类簇。 -
绘制散点图:使用
matplotlib
绘制聚类结果的散点图。每个聚类簇用不同颜色和标记进行标识,聚类中心用黑色星号标记。 -
设置图表标题、坐标轴标签和图例等,最终展示绘制的图像。
def kmeans(X, n_clusters, max_iters=100):
# 随机初始化聚类中心
centers = X[np.random.choice(range(len(X)), size=n_clusters, replace=False)]
for _ in range(max_iters):
# 计算每个样本到聚类中心的距离
distances = np.sqrt(((X - centers[:, np.newaxis]) ** 2).sum(axis=2))
# 根据距离选择最近的聚类中心
labels = np.argmin(distances, axis=0)
# 更新聚类中心
new_centers = np.array([X[labels == i].mean(axis=0) for i in range(n_clusters)])
# 如果聚类中心没有变化,提前结束迭代
if np.all(centers == new_centers):
break
centers = new_centers
return labels, centers
# 读取地图坐标文件
data = pd.read_csv('china_cities.csv')
provinces = data.iloc[:, 0].values
cities = data.iloc[:, 1].values
coordinates = data.iloc[:, 2:4].values
# 执行聚类操作
n_clusters = 7 # 设置聚类簇数为7
labels, centers = kmeans(coordinates, n_clusters)
# 绘制聚类结果散点图
plt.figure(figsize=(10, 8))
colors = ['red', 'blue', 'green', 'purple', 'orange', 'yellow', 'cyan']
markers = ['o', '^', 's', 'x', 'v', 'D', 'P']
for i in range(n_clusters):
plt.scatter(coordinates[labels == i, 1], coordinates[labels == i, 0], s=50, c=colors[i], marker=markers[i], label=f'Cluster {i+1}')
# 绘制聚类中心
plt.scatter(centers[:, 1], centers[:, 0], s=100, c='black', marker='*', label='Centers')
plt.title('Major Cities in china')
plt.xlabel('North Longitude')
plt.ylabel('East Latitude')
plt.legend()
plt.show()