package com.i9i.rpc;
import java.awt.Color;
import java.awt.Graphics;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import javax.swing.JFrame;
import javax.swing.SwingUtilities;
public class Kmean extends JFrame {
/***
*
* @param centerPoints 中心点
* @param data 输入样本
* @return
*/
public static Map<Point, List<Point>> getPointsGroup(List<Point> centerPoints, List<Point> data) {
Map<Point, List<Point>> kdata = new TreeMap<Point, List<Point>>();
for (Point jp : centerPoints) {
kdata.put(jp, new ArrayList<Point>());
}
for (Point ip : data) {
Point spoint = null;
double lastdistance = 0;
double currdistance = 0;
for (Point jp : centerPoints) {
currdistance = getDistince(jp, ip);
if (spoint == null || lastdistance > currdistance) {
lastdistance = currdistance;
spoint = jp;
}
}
kdata.get(spoint).add(ip);
}
return kdata;
}
/******
* 依据分类,获取新的中心点
*
* @param mpoints 分类键对
* @return
*/
public static List<Point> getCenterPoints(Map<Point, List<Point>> mpoints) {
List<Point> newPoints = new ArrayList<Point>();
Set<Point> key = mpoints.keySet();
for (Point p : key) {
List<Point> lpoints = mpoints.get(p);
double sumx = 0, sumy = 0;
for (Point lpoint : lpoints) {
sumx += lpoint.getX();
sumy += lpoint.getY();
}
Point newcenterPoint = new Point();
newcenterPoint.setX(sumx / lpoints.size());
newcenterPoint.setY(sumy / lpoints.size());
newPoints.add(newcenterPoint);
}
return newPoints;
}
public static List<Point> initData(int n) {
List<Point> list = new ArrayList<Point>();
java.util.Random rd = new Random(3000);
for (int i = 0; i < n; i++) {
Point p = new Point();
p.setX(rd.nextDouble() * 1000);
p.setY(rd.nextDouble() * 1000);
list.add(p);
}
return list;
}
public static void printPair(Map<Point, List<Point>> cpoints) {
Set<Point> key = cpoints.keySet();
for (Point p : key) {
System.out.print(p + " :{");
List<Point> list = cpoints.get(p);
for (Point lp : list) {
System.out.print(lp + ",");
}
System.out.println("}");
}
}
public static void printResult(List<Map<Point, List<Point>>> cpoints) {
int i = 0;
for (Map<Point, List<Point>> mp : cpoints) {
System.out.println("第" + i + "计算分类信息");
Set<Point> key = mp.keySet();
for (Point p : key) {
System.out.print(p + " :{");
List<Point> list = mp.get(p);
for (Point lp : list) {
System.out.print(lp + ",");
}
System.out.println("}");
}
i++;
}
}
public static double getDistince(Point p1, Point p2) {
return Math.sqrt(
(p1.getX() - p2.getX()) * (p1.getX() - p2.getX()) + (p1.getY() - p2.getY()) * (p1.getY() - p2.getY()));
}
public static double sumDistince(Point p1, List<Point> ps) {
double sum = 0;
for (Point p : ps) {
sum = sum + getDistince(p1, p);
}
return sum;
}
List<Point> data = initData(500);
List<Point> kpoint = initData(8);
List<List<Point>> kpoints = new ArrayList<List<Point>>();
Map<Point, List<Point>> result = new TreeMap<>();
// List<Map<Point, List<Point>>> cpoints = new ArrayList<Map<Point, List<Point>>>();
Kmean() {
// this.resize(1000, 1000);
this.setVisible(true);
this.setSize(1000, 1000);
}
public void cal() throws InterruptedException {
double lastDistince = 0;
double currDistince = 0;
long count = 1;
double minDistince = Long.MAX_VALUE;
while (true) {
System.out.println("正在进行" + count + "次迭代");
Map<Point, List<Point>> pair = getPointsGroup(kpoint, data);
result = pair;
printPair(pair);
// cpoints.add(pair);
Set<Point> key = pair.keySet();
lastDistince = currDistince;
currDistince = 0;
for (Point p : key) {
currDistince += sumDistince(p, pair.get(p));
}
if (currDistince < minDistince) {
minDistince = currDistince;
}
this.update(this.getContentPane().getGraphics());
TimeUnit.MILLISECONDS.sleep(1800);
count++;
if (lastDistince <= currDistince
&& (currDistince < 0.001 || Math.abs(lastDistince - currDistince) / lastDistince < 0.0001))
break;
else
kpoint = getCenterPoints(pair);
}
}
/**
* @param args
* @throws InterruptedException
*/
public static void main(String[] args) throws InterruptedException {
Kmean demo = new Kmean();
demo.cal();
// printResult(cpoints);
}
public void clear(Graphics g) {
g.setColor(getBackground());
g.fillRect(0, 0, getWidth(), getHeight());
paint(g);
}
@Override
public void update(Graphics g) {
super.update(g);
clear(g);
Color[] color = { Color.BLUE, Color.RED, Color.yellow, Color.GREEN, Color.PINK, Color.ORANGE, Color.MAGENTA,
Color.BLACK };
int i = 0;
int r = 10;
for (Point key : result.keySet()) {
g.setColor(color[i]);
g.fillArc(key.getX().intValue(), key.getY().intValue(), r * 2, r * 2, 0, 360);
List<Point> points = result.get(key);
for (Point p : points) {
g.fillArc(p.getX().intValue(), p.getY().intValue(), r, r, 0, 360);
}
i++;
}
}
static class Point implements Comparable<Point> {
Double x, y;
public Double getX() {
return x;
}
public void setX(double x) {
this.x = x;
}
public Double getY() {
return y;
}
public void setY(double y) {
this.y = y;
}
@Override
public String toString() {
return "(" + x + ", " + y + ")";
}
@Override
public int compareTo(Point o) {
return this.getX().intValue() - o.getX().intValue() + this.getY().intValue() - o.getY().intValue();
}
}
}
// UI version of the k-means algorithm (ui版kmeans算法).
// First published 2023-10-18 21:39:41 (于 2023-10-18 21:39:41 首次发布).