package xxx;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
public class KMeans {
public static class Point {
float x, y, z;
public Point(float x, float y, float z) {
this.x = x;
this.y = y;
this.z = z;
}
public static float getDistanceSquare(Point a, Point b) {
float sum = 0.0f;
float diff = 0.0f;
diff = a.x - b.x;
sum += diff * diff;
diff = a.y - b.y;
sum += diff * diff;
diff = a.z - b.z;
sum += diff * diff;
return sum;
}
public String toString() {
StringBuffer buffer = new StringBuffer();
buffer.append("(").append(x).append(",").append(y).append(",")
.append(z).append(")");
return new String(buffer);
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Float.floatToIntBits(x);
result = prime * result + Float.floatToIntBits(y);
result = prime * result + Float.floatToIntBits(z);
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Point other = (Point) obj;
if (Float.floatToIntBits(x) != Float.floatToIntBits(other.x))
return false;
if (Float.floatToIntBits(y) != Float.floatToIntBits(other.y))
return false;
if (Float.floatToIntBits(z) != Float.floatToIntBits(other.z))
return false;
return true;
}
}
public static class Cluster {
Point center;
List<Point> members = new LinkedList<Point>();
public Cluster(Point center) {
this.center = center;
}
public Point getCenter() {
return center;
}
public void addMember(Point point) {
members.add(point);
}
public void reset(Point center) {
this.center = center;
members.clear();
}
public Point computeNewCenter() {
float x = 0.0f;
float y = 0.0f;
float z = 0.0f;
for (Point member : members) {
x += member.x;
y += member.y;
z += member.z;
}
x /= members.size();
y /= members.size();
z /= members.size();
return new Point(x, y, z);
}
public float sumDistanceSquares() {
float sum = 0.0f;
for (Point member : members)
sum += Point.getDistanceSquare(member, center);
return sum;
}
public String toString() {
StringBuffer buffer = new StringBuffer();
buffer.append(center).append(" --> {");
Iterator<Point> it = members.iterator();
while (it.hasNext()) {
Point member = it.next();
buffer.append(member);
if (it.hasNext())
buffer.append(", ");
}
buffer.append("}");
return new String(buffer);
}
}
public static void computeClusters(Point[] points, int K) {
if (K > points.length) {
System.out.println("聚类个数(K)不能大于待聚类的对象个数");
return;
}
List<Cluster> clusters = new ArrayList<Cluster>(K);
int[] indexes = randomSelectIndexes(points.length, K);
System.out.println("随机选取初始中心:");
for (int i = 0; i < K; i++)
System.out.println(points[indexes[i]]);
System.out.println();
for (int i = 0; i < K; i++)
clusters.add(new Cluster(points[indexes[i]]));
int iterateTime = 0;
boolean stable = false;
while (!stable) {
for (Point point : points) {
Cluster minCluster = null;
float min = Float.MAX_VALUE;
for (Cluster cluster : clusters) {
float cur = Point.getDistanceSquare(point,
cluster.getCenter());
if (minCluster == null || cur < min) {
minCluster = cluster;
min = cur;
}
}
minCluster.addMember(point);
}
iterateTime++;
float e = 0.0f;
for (Cluster cluster : clusters)
e += cluster.sumDistanceSquares();
System.out.printf("第%d次迭代后,E=%f,迭代结果:\n", iterateTime, e);
for (Cluster cluster : clusters)
System.out.println(cluster);
stable = true;
for (Cluster cluster : clusters) {
Point center = cluster.getCenter();
Point newCenter = cluster.computeNewCenter();
if (stable && !newCenter.equals(center))
stable = false;
cluster.reset(newCenter);
}
if (stable)
System.out.println("聚类已稳定,结束。");
else
System.out.println("聚类不稳定,继续...\n");
}
}
private static int[] randomSelectIndexes(int limit, int num) {
Set<Integer> selectedIndexes = new TreeSet<Integer>();
Random random = new Random();
for (int i = 0; i < num; i++) {
int index = -1;
while (index == -1 || selectedIndexes.contains(index))
index = random.nextInt(limit);
selectedIndexes.add(index);
}
int[] ret = new int[num];
int i = 0;
for (Integer index : selectedIndexes)
ret[i++] = index;
return ret;
}
public static void main(String[] args) {
Point[] points = new Point[] { new Point(13, 26, 45),
new Point(25, 63, 18), new Point(16, 22, 45),
new Point(72, 33, 18), new Point(9, 62, 14),
new Point(22, 11, 90), new Point(80, 13, 20),
new Point(78, 8, 32), new Point(55, 43, 9),
new Point(12, 31, 73) };
computeClusters(points, 3);
}
}
// K-Means Java example program
// (web page footer: latest recommended article published on 2024-01-27 14:59:05)