【0】README
0.1)本文描述和源代码均为原创,旨在说明 如何将模板方法模式应用到kmean聚类算法;
0.2)模板方法模式的intro, 参见 模板方法模式
0.3)for kmeans alg source code, please visit kmeans&templateMethodPattern
【1】intro to kmeans
1.1)准备工作:随机初始化聚类质心(随机更新质心),读取数据等;
1.2)核心算法:
for i in round
step1)为每个样本item分配质心;
step2)更新质心;
step3)打印每轮聚类结果(可选)或重定向到 持久化文件;
public final void cluster() {// 聚类方法
randomRefineInitialCentroid(); // 随机初始化聚类质心
for (int i = 0; i < ClusterParam.clusterExeTimes; i++) {
Arrays.fill(ClusterData.clusterMemberNum, 0); // 清空聚类成员个数记录
for (int j = 0; j < ClusterData.rowno; j++){
ClusterData.rownoWithClusterid[j] = assign(j);
}
clearDoubleArray(ClusterData.centroid); // 清空质心
refine();
PrintResult.print(i+1);
}
}
【2】将模板方法模式应用到kmean聚类算法
2.1)intro to 模板方法模式: 一句话说完,模板方法模式就是为 封装算法而生的,主要是在基类 先对 算法的steps 给出 outline, 然后 抽取各个子类的共同 steps 到 基类进行具体实现,其他的都抽象为抽象方法 有子类实现;这样一来,整个alg 的 steps 无比清晰,且易于扩展,特别是对于学术型alg,有很多变体算法,如基于kmeans的 聚类算法就有很多;
2.2) kmeans 聚类算法基类和子类
package com.research.alg2;
import static java.lang.System.out;
import java.util.Arrays;
import com.research.io2.PrintResult;
import com.research.pojo2.ClusterData;
import com.research.pojo2.ClusterParam;
public abstract class ClusterAlg {
public static String algName;
abstract int assign(int index); // 为 第 index 个item 分配质心,返回结果是质心编号index
abstract void refine();// 精炼质心
public final void cluster() {// 聚类方法
randomRefineInitialCentroid(); // 随机初始化聚类质心
for (int i = 0; i < ClusterParam.clusterExeTimes; i++) {
Arrays.fill(ClusterData.clusterMemberNum, 0); // 清空聚类成员个数记录
for (int j = 0; j < ClusterData.rowno; j++){
ClusterData.rownoWithClusterid[j] = assign(j);
}
clearDoubleArray(ClusterData.centroid); // 清空质心
refine();
PrintResult.print(i+1);
}
}
// reset centroid array zeros
final void clearIntArray(int[][] data) {
for (int i = 0; i < data.length; i++)
Arrays.fill(data[i], 0);
}
// reset centroid array zeros
final void clearDoubleArray(double[][] data) {
for (int i = 0; i < data.length; i++)
Arrays.fill(data[i], 0);
}
// randomly update or refine init centroids
final void randomRefineInitialCentroid() {
int[] initCentorid = generateRandom(ClusterData.rowno, ClusterParam.clusterNum);
System.out.println("==== init centroids are as follows:");
for (int i = 0; i < initCentorid.length; i++){
ClusterData.centroid[i] = ClusterData.items[initCentorid[i]].clone();
out.printf("%-8s", "item" + initCentorid[i]);
}
out.printf("\n============================================================================\n");
}
/**
* fabricate random array
* @param volumn , random number upper limit
* @param interval , interval number and there is a random number in every interval
* @return a random array
*/
final int[] generateRandom(int volume, int interval) {
int[] r_data = new int[interval];
int intervalVolume = volume / interval;
for (int i = 0; i < interval; i++) {
int r = (int) (Math.random() * intervalVolume);
r_data[i] = r + intervalVolume * i;
}
// r_data[0] = 1;
// r_data[1] = 101;
// r_data[2] = 301;
// r_data[3] = 501;
// r_data[4] = 701;
// r_data[5] = 901;
return r_data;
}
}
package com.research.alg2;
import static java.lang.Math.pow;
import com.research.pojo2.ClusterData;
import com.research.pojo2.ClusterParam;
public class KmeansAlg extends ClusterAlg {
/**
* compute the centroid the item should be assigned to
* @param index refers to item index
* @return cluster id whose has the smallest distance between centroid and
* the item
*/
@Override
int assign(int index) {
double sum = 0;
double miniSum = 0;
int miniIndex = 0;
double[] item = ClusterData.items[index];
for (int i = 0; i < ClusterData.dimension; i++)
sum += pow(item[i] - ClusterData.centroid[0][i], 2.0);
miniSum = sum;
for (int i = 1; i < ClusterParam.clusterNum; i++) {
sum = 0;
for (int j = 0; j < ClusterData.dimension; j++)
sum += Math.pow(item[j] - ClusterData.centroid[i][j], 2.0);
if (miniSum > sum) {
miniSum = sum;
miniIndex = i;
}
}
ClusterData.clusterMemberNum[miniIndex]++;
return miniIndex;
}
@Override
void refine() {
int clusterId;
for (int i = 0; i < ClusterData.rowno; i++) {
clusterId = ClusterData.rownoWithClusterid[i];
for (int j = 0; j < ClusterData.dimension; j++)
ClusterData.centroid[clusterId][j] += ClusterData.items[i][j];
}
// update centroids(refinement procedure)
for (int i = 0; i < ClusterParam.clusterNum; i++)
for (int j = 0; j < ClusterData.dimension; j++)
ClusterData.centroid[i][j] /= ClusterData.clusterMemberNum[i];
}
}