需求来源
近期有个需求,需要对数据进行聚类处理,这里做个记录。
需求分析
1、了解一下聚类处理算法
2、数据构建
3、聚类处理
4、获取结果
实现方案
1、简单说一下k-means聚类算法:
2、数据构建
如果仅对数据做聚类处理,源数据使用数组或者集合来处理更方便
3、欧氏距离算法:
代码实现
代码如下,可以做个参考:
package com.***.***.**;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.alibaba.fastjson.JSON;
import net.sf.json.JSONArray;
import org.apache.commons.lang3.math.NumberUtils;
public class TestOSJL {
/**
* 安全随机数
*/
private static final SecureRandom RANDOM = new SecureRandom();
public static void main(String[] args) {
// 源数据
// 如果仅对数据做聚类处理,源数据使用数组或者集合来处理更方便
Map<String, double[]> dataMap;
// 数据获取
dataMap = makeData(new JSONArray());
// 定义维度
int dimensions = 3;
// 簇值:k
int k = 3;
// 聚类处理
kMeans(dataMap, k, dimensions);
}
/**
* 对获取的数据进行校验和数据组装
*
* @param dataList xx数据集合
* @return 结果
*/
private static Map<String, double[]> makeData(JSONArray dataList) {
// if (dataList.isEmpty()) {
if (false) {
System.out.println("xx数据集合为空!");
return new HashMap<>();
}
else {
// 返回结果
Map<String, double[]> resulMap = new HashMap<>();
// 初始化值
double[] a = {
10d, 6d, 16d
};
resulMap.put("a", a);
double[] b = {
15d, 7d, 18d
};
resulMap.put("b", b);
double[] c = {
20d, 8d, 22d
};
resulMap.put("c", c);
double[] d = {
25d, 9d, 25d
};
resulMap.put("d", d);
double[] e = {
30d, 10d, 19d
};
resulMap.put("e", e);
// 数据校验
System.out.println("数据校验!确保每个点都有数据!");
return resulMap;
}
}
/**
* 聚类处理方法
*
* @param dataMap 数据集合
* @param k 簇值
* @param dimensions 维度值
*/
private static void kMeans(Map<String, double[]> dataMap, int k, int dimensions) {
// 质心集合
List<double[]> centroidList;
// 分组集合
List<Map<String, double[]>> groupList = new ArrayList<>();
int count = 1;
do {
System.out.println("--------第" + count + "次聚类-------");
if (count == 1) {
// 首次聚类
// 随机获取k个质心
centroidList = getCentroidList(dataMap, k);
System.out.println("第" + count + "次聚类质心集合为:" + JSON.toJSON(centroidList));
// 获取分组数据
groupList = euDistance(dataMap, centroidList, k);
System.out.println("第" + count + "次聚类分组集合为:" + JSON.toJSON(groupList));
count++;
}
else {
// 获取新的质心集合
centroidList = getNewCentroidList(groupList, dimensions);
System.out.println("第" + count + "次聚类质心集合为:" + JSON.toJSON(centroidList));
// 获取新的分组
List<Map<String, double[]>> newGroupList = euDistance(dataMap, centroidList, k);
System.out.println("第" + count + "次聚类分组集合为:" + JSON.toJSON(newGroupList));
// 比较新旧分组是否相同
if (newGroupList.containsAll(groupList)) {
System.out.println("第" + count + "次聚类质心未发生移动,跳出循环");
break;
}
else {
// 将新组值赋给旧组
groupList = newGroupList;
}
}
}
while (true);
System.out.println("最终质心集合为:" + JSON.toJSON(centroidList));
System.out.println("最终分组结果为:" + JSON.toJSON(groupList));
}
/**
* 初始化分组数据
*
* @param k 簇值
* @return 结果
*/
private static List<Map<String, double[]>> getGroupList(int k) {
List<Map<String, double[]>> groupList = new ArrayList<>();
for (int i = 0; i < k; i++) {
Map<String, double[]> map = new HashMap<>();
groupList.add(map);
}
return groupList;
}
/**
* 获取新的质心集合
*
* @param newGroupList 新的分组数据
* @param dimensions 维度值
* @return 结果
*/
private static List<double[]> getNewCentroidList(List<Map<String, double[]>> newGroupList, int dimensions) {
List<double[]> newCentroidList = new ArrayList<>();
System.out.println("当前所有的分组的值为:" + JSON.toJSON(newGroupList));
for (int i = 0; i < newGroupList.size(); i++) {
Map<String, double[]> groupMap = newGroupList.get(i);
// 当前分组的个数
int num = groupMap.size();
// 用数组存储指标值集合
double[] indexValueList = new double[dimensions];
for (Map.Entry<String, double[]> entry : groupMap.entrySet()) {
double[] value = entry.getValue();
for (int a = 0; a < value.length; a++) {
indexValueList[a] += value[a];
}
}
double[] newCentroid = Arrays.stream(indexValueList).map(value -> value / num).toArray();
newCentroidList.add(newCentroid);
System.out
.println("第" + (i + 1) + "组的数据为:" + JSON.toJSON(groupMap) + ",其新的质心为:" + JSON.toJSON(newCentroid));
}
return newCentroidList;
}
/**
* 获取分组数据
*
* @param dataMap 数据集合
* @param centroidList 质心集合
* @param k 簇值
* @return 结果
*/
private static List<Map<String, double[]>> euDistance(Map<String, double[]> dataMap, List<double[]> centroidList,
int k) {
List<Map<String, double[]>> groupList = getGroupList(k);
for (Map.Entry<String, double[]> entry : dataMap.entrySet()) {
String key = entry.getKey();
double[] value = entry.getValue();
// 获取最小距离并且定位到分组的下标
Map<String, String> numAndDistance = getMinDistanceAndGroupNum(value, centroidList);
groupList.get(NumberUtils.toInt(numAndDistance.get("groupNum"))).put(key, value);
System.out.println(key + "的指标为:" + JSON.toJSON(value) + ",其到质心组的最小距离为:" + numAndDistance.get("minDistance")
+ ",将其分配至:" + (NumberUtils.toInt(numAndDistance.get("groupNum")) + 1) + "组");
}
return groupList;
}
/**
* 获取最小距离和分组下标
*
* @param value 需要处理的数据
* @param centroidList 质心集合
* @return 结果
*/
private static Map<String, String> getMinDistanceAndGroupNum(double[] value, List<double[]> centroidList) {
Map<String, String> numAndDistance = new HashMap<>();
// 最小距离
double minDistance = 0d;
// 定位到哪一组(默认第一组)
int groupNum = 0;
for (int i = 0; i < centroidList.size(); i++) {
if (i == 0) {
minDistance = euclideanDistance(value, centroidList.get(i));
}
else {
double distance = euclideanDistance(value, centroidList.get(i));
// 获取小的值
if (distance < minDistance) {
// 将小值赋给minDistance
minDistance = distance;
groupNum = i;
}
}
}
numAndDistance.put("minDistance", minDistance + "");
numAndDistance.put("groupNum", groupNum + "");
return numAndDistance;
}
/**
* 获取质心数据集合
*
* @param dataMap 原数据集合
* @param k 簇值
* @return 结果
*/
private static List<double[]> getCentroidList(Map<String, double[]> dataMap, int k) {
// dataMap的key和size一一对应
Map<Integer, String> randomMap = new HashMap<>();
int size = 0;
for (Map.Entry<String, double[]> entry : dataMap.entrySet()) {
randomMap.put(size, entry.getKey());
size++;
}
List<double[]> resultList = new ArrayList<>();
// 获取随机数
int[] result = new int[k];
for (int i = 0; i < result.length;) {
result[i] = RANDOM.nextInt(dataMap.size());
if (verifyRandom(result, result[i], i)) {
resultList.add(dataMap.get(randomMap.get(result[i])));
i++;
}
}
return resultList;
}
/**
* 计算两点之间的欧氏距离
*
* @param point1 点1数据
* @param point2 点2数据
* @return 结果
*/
public static double euclideanDistance(double[] point1, double[] point2) {
double sum = 0.0;
for (int i = 0; i < point1.length; i++) {
sum += Math.pow(point1[i] - point2[i], 2);
}
return Math.sqrt(sum);
}
/**
* 检查生成的随机数是否存在与数组中
*
* @param data 返回值列表
* @param result 当前返回的值
* @param l 下标
* @return 结果
*/
public static boolean verifyRandom(int data[], int result, int l) {
for (int i = 0; i < data.length; i++) {
if (data[i] == result && l != i) {
return false;
}
}
return true;
}
}