def kmeans(data, k, max_iter=100):
    """Cluster the rows of a 2-D tensor with Lloyd's k-means algorithm.

    Initial centers are ``k`` distinct rows of ``data`` picked with
    ``torch.randperm``; seed the torch RNG for reproducible results.

    Args:
        data: Tensor of shape ``(n, m)``; each row is one point. Must be a
            floating-point tensor, because centers are computed with
            ``torch.mean``.
        k: Number of clusters, ``1 <= k <= n``.
        max_iter: Maximum number of center-update iterations. With
            ``max_iter <= 0`` the initial centers and their labels are
            returned.

    Returns:
        ``(centers, label)``, where ``centers`` has shape ``(k, m)`` and
        ``label`` is an int64 tensor of shape ``(n,)``. ``label[j]`` is the
        index of the nearest center (in the returned ``centers``) for row
        ``j``.

    Raises:
        ValueError: if ``data`` is not 2-D or ``k`` is outside ``[1, n]``.
    """
    if data.dim() != 2:
        raise ValueError(f"data must be 2-D (n, m), got shape {tuple(data.shape)}")
    n = data.size(0)
    if not 1 <= k <= n:
        # randperm(n)[:k] can only supply n distinct initial centers.
        raise ValueError(f"k must be in [1, {n}], got {k}")

    def assign(centers):
        # Squared Euclidean distance from every point to every center: (n, k).
        distance = torch.sum((data.unsqueeze(1) - centers.unsqueeze(0)).pow(2), 2)
        # Ties resolve to the lowest center index (torch.argmin semantics).
        return torch.argmin(distance, 1)

    centers = data[torch.randperm(n)[:k]]
    # Label up front so `label` always exists and always matches `centers`,
    # even when max_iter <= 0 or the loop ends without converging.
    label = assign(centers)
    for _ in range(max_iter):
        new_centers = centers.clone()
        for i in range(k):
            members = data[label == i]
            # torch.mean over zero rows yields NaN, which would poison this
            # center permanently; an empty cluster keeps its old position.
            if members.size(0) > 0:
                new_centers[i] = members.mean(0)
        if torch.equal(new_centers, centers):
            break
        centers = new_centers
        label = assign(centers)
    return centers, label
pytorch计算kmeans
最新推荐文章于 2024-10-18 12:08:45 发布