# K-means 实现

import numpy as np
"""
初始化:随机初始化容易陷入局部最优,kmeans++ 能显著提升稳定性和收敛速度(见实现)。

距离度量:通常用欧氏距离(平方和),如果是其他数据类型可替换度量。

空簇处理:若某簇没有点,下面的实现把该簇质心重置为“到其最近质心距离最大的样本点”。也可重新随机采样或合并簇。

收敛判定:常用质心移动量小于 tol 或 assignments 不再变化,或达 max_iter。

复杂度:每次迭代计算距离 O(n * k * d),整体 O(n k d t)(t 为迭代次数)。

数值/实现优化:大数据时避免一次性构造 n×k 矩阵(内存),可以分批或用更高效的向量化/BLAS,或使用 KD-tree / approximate methods。
"""
class KMeans:
    """Lloyd's K-means clustering with k-means++ or random initialization.

    Parameters
    ----------
    n_clusters : int
        Number of clusters ``k``.
    max_iter : int
        Maximum number of assignment/update iterations.
    tol : float
        Convergence threshold: iteration stops once every center moves by at
        most ``tol`` (Euclidean norm) in a single update.
    init : str
        ``'kmeans++'`` selects k-means++ seeding; any other value picks ``k``
        distinct samples uniformly at random.
    random_state : int or None
        Seed passed to ``np.random.RandomState`` by the initializers.

    Attributes (populated by ``fit``)
    ----------
    cluster_centers_ : ndarray of shape (k, n_features)
        Final cluster centers.
    labels_ : ndarray of shape (n_samples,)
        Index of the nearest final center for each training sample.
    inertia_ : float
        Within-cluster sum of squared distances (WCSS) for the final centers.
    """

    def __init__(self, n_clusters=3, max_iter=300, tol=1e-4, init='kmeans++', random_state=None):
        self.k = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.init = init
        self.random_state = random_state
        self.cluster_centers_ = None
        self.labels_ = None
        self.inertia_ = None  # WCSS (within-cluster sum of squares)

    @staticmethod
    def _squared_distances(X, centers):
        """Return the (n_samples, k) matrix of squared Euclidean distances."""
        # Broadcasting materializes an (n_samples, k, n_features) temporary;
        # fine for small/medium data, chunk it for very large inputs.
        return np.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=2)

    def _kmeans_plus_plus_init(self, X):
        """Choose ``k`` initial centers using k-means++ D^2 weighting."""
        rng = np.random.RandomState(self.random_state)
        n_samples, n_features = X.shape
        centers = np.empty((self.k, n_features), dtype=X.dtype)

        # First center: uniform over the samples.
        centers[0] = X[rng.randint(0, n_samples)]

        # Squared distance from each sample to its nearest chosen center.
        closest_dist_sq = np.sum((X - centers[0]) ** 2, axis=1)

        for c in range(1, self.k):
            total = closest_dist_sq.sum()
            if total > 0:
                probs = closest_dist_sq / total
            else:
                # Every sample coincides with an already-chosen center, so all
                # D^2 weights are zero (0/0 would yield NaN probabilities and
                # make rng.choice raise). Fall back to a uniform pick.
                probs = None
            idx = rng.choice(n_samples, p=probs)
            centers[c] = X[idx]
            dist_sq = np.sum((X - centers[c]) ** 2, axis=1)
            closest_dist_sq = np.minimum(closest_dist_sq, dist_sq)

        return centers

    def _random_init(self, X):
        """Choose ``k`` distinct samples uniformly at random as initial centers."""
        rng = np.random.RandomState(self.random_state)
        idx = rng.choice(X.shape[0], self.k, replace=False)
        return X[idx].copy()

    def _update_centers(self, X, labels, dists):
        """Return the new centers: cluster means, with empty clusters re-seeded.

        An empty cluster is moved to the sample whose distance to its nearest
        current center is largest. Each re-seed uses a different sample, so
        several empty clusters never collapse onto the same point.
        """
        new_centers = np.empty((self.k, X.shape[1]), dtype=X.dtype)
        # dists.min returns a fresh array, so marking entries below does not
        # touch the caller's distance matrix.
        nearest = dists.min(axis=1)
        for j in range(self.k):
            members = labels == j
            if members.any():
                new_centers[j] = X[members].mean(axis=0)
            else:
                far = np.argmax(nearest)
                new_centers[j] = X[far]
                # Exclude this sample from further re-seeding.
                nearest[far] = -np.inf
        return new_centers

    def fit(self, X):
        """Run K-means on ``X`` and store centers, labels and inertia.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data; converted to a float ndarray.

        Returns
        -------
        KMeans
            ``self``, to allow chaining.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or has fewer samples than ``n_clusters``.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D (n_samples, n_features), got shape {X.shape}")
        n_samples = X.shape[0]
        if n_samples < self.k:
            raise ValueError(f"n_samples={n_samples} must be >= n_clusters={self.k}")

        if self.init == 'kmeans++':
            centers = self._kmeans_plus_plus_init(X)
        else:
            centers = self._random_init(X)

        for _ in range(self.max_iter):
            # Assignment step: nearest center for every sample.
            dists = self._squared_distances(X, centers)
            labels = np.argmin(dists, axis=1)

            # Update step.
            new_centers = self._update_centers(X, labels, dists)

            # Convergence: stop when no center moved more than tol.
            shifts = np.linalg.norm(new_centers - centers, axis=1)
            centers = new_centers
            if np.all(shifts <= self.tol):
                break

        # Final assignment against the returned centers. Inside the loop the
        # labels were computed from the *previous* centers, so without this
        # step labels_/inertia_ could disagree with cluster_centers_/predict().
        dists = self._squared_distances(X, centers)
        labels = np.argmin(dists, axis=1)

        self.inertia_ = dists.min(axis=1).sum()
        self.cluster_centers_ = centers
        self.labels_ = labels
        return self

    def predict(self, X):
        """Return the index of the nearest fitted center for each row of ``X``."""
        X = np.asarray(X, dtype=float)
        dists = self._squared_distances(X, self.cluster_centers_)
        return np.argmin(dists, axis=1)

    def fit_predict(self, X):
        """Fit on ``X`` and return the resulting training labels."""
        self.fit(X)
        return self.labels_
if __name__ == "__main__":
    # Demo data: two well-separated groups of three 2-D points each.
    X = np.array(
        [[0, 0], [0, 1], [1, 0],
         [8, 8], [8, 9], [9, 8]]
    )

    # Fit K-means with k=2; fit() returns the estimator, so chain the call.
    kmeans = KMeans(n_clusters=2, random_state=0).fit(X)

    # Print labels, centers and inertia, one captioned line each.
    report = (
        ("簇标签:", kmeans.labels_),
        ("质心:", kmeans.cluster_centers_),
        ("簇内平方和 (inertia):", kmeans.inertia_),
    )
    for caption, value in report:
        print(caption, value)

# 评论
# 成就一亿技术人!
# 拼手气红包6.0元
# 还能输入1000个字符
#  
# 红包 添加红包
# 表情包 插入表情
#  条评论被折叠 查看
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值