// 参见 http://blog.youkuaiyun.com/u014568921/article/details/45197027
// meanshift-cluster.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"

#include <assert.h>
#include <time.h>

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

using namespace std;
#define MSTYPE double

/// Mean-shift clustering of randomly generated points.
///
/// Every point is repeatedly replaced by the kernel-weighted mean of the
/// (unchanging) original dataset. Iteration stops once no point moves by
/// more than 1e-8 during a full pass, at which point each point sits on a
/// local mode of the kernel density estimate.
class meanshift
{
private:
	/// One point in feature space; its dimensionality is data.size().
	struct MSData
	{
		std::vector<MSTYPE> data;

		/// Creates a zero-initialised point with `d` coordinates.
		explicit MSData(std::size_t d) : data(d) {}
	};

	std::vector<MSData> dataset;
	double kernel_bandwidth;

	/// Returns the kernel-weighted mean of `dataset` as seen from `vec`.
	/// If every weight is zero (e.g. exp() underflow), `vec` is returned
	/// unchanged instead of dividing 0 by 0 and producing NaN coordinates.
	MSData shiftvec(const MSData &vec) const
	{
		const std::size_t dims = vec.data.size();
		MSData shiftvector(dims);
		double total_weight = 0;
		for (std::size_t i = 0; i < dataset.size(); ++i) {
			const MSData &sample = dataset[i];
			const double weight = gaussian_kernel(euclidean_distance(vec, sample));
			for (std::size_t j = 0; j < dims; ++j) {
				shiftvector.data[j] += sample.data[j] * weight;
			}
			total_weight += weight;
		}
		if (total_weight <= 0) {
			return vec;  // no sample carries measurable weight: do not move
		}
		for (std::size_t j = 0; j < dims; ++j) {
			shiftvector.data[j] /= total_weight;
		}
		return shiftvector;
	}

	/// Unnormalised kernel weight exp(-d^2 / h), with h = kernel_bandwidth.
	/// NOTE(review): the textbook Gaussian kernel is exp(-d^2 / (2 h^2));
	/// the original scaling is kept so existing clustering results do not change.
	double gaussian_kernel(double distance) const
	{
		return std::exp(-(distance * distance) / kernel_bandwidth);
	}

	/// Euclidean distance between two points of equal dimensionality
	/// (asserted).
	double euclidean_distance(const MSData &data1, const MSData &data2) const
	{
		assert(data1.data.size() == data2.data.size());
		double sum = 0;
		for (std::size_t i = 0; i < data1.data.size(); ++i) {
			const double diff = data1.data[i] - data2.data[i];
			sum += diff * diff;
		}
		return std::sqrt(sum);
	}

	/// Writes `p` as "(c0,c1,...)" for any dimensionality; for 2-D points
	/// the text matches the previous hard-coded "(x,y)" format.
	static void print_point(std::ostream &os, const MSData &p)
	{
		os << "(";
		for (std::size_t k = 0; k < p.data.size(); ++k) {
			if (k != 0) {
				os << ",";
			}
			os << p.data[k];
		}
		os << ")";
	}

public:
	/// @param kernel_bandwidth  bandwidth h used by gaussian_kernel().
	/// Also seeds rand() from the current time, so generatedata() yields a
	/// different dataset on each run.
	meanshift(double kernel_bandwidth) : kernel_bandwidth(kernel_bandwidth)
	{
		std::time_t t;
		std::srand(static_cast<unsigned int>(std::time(&t)));
	}

	/// Runs mean shift on the current dataset.
	/// It prints the largest shift of every pass. At the end it prints each
	/// original point next to the position it converged to.
	/// @return converged positions, index-aligned with the dataset.
	std::vector<MSData> apply()
	{
		// Scoped constant instead of a function-body #define, so it does not leak.
		const double epsilon = 0.00000001;
		std::vector<int> stop_moving(dataset.size(), 0);
		std::vector<MSData> shifted_points = dataset;
		double max_shift_distance;
		do {
			max_shift_distance = 0;
			for (std::size_t i = 0; i < shifted_points.size(); ++i) {
				if (stop_moving[i]) {
					continue;  // already converged
				}
				MSData point_new = shiftvec(shifted_points[i]);
				const double shift_distance = euclidean_distance(point_new, shifted_points[i]);
				if (shift_distance > max_shift_distance) {
					max_shift_distance = shift_distance;
				}
				if (shift_distance <= epsilon) {
					stop_moving[i] = 1;
				}
				shifted_points[i] = point_new;
			}
			std::printf("max_shift_distance: %f\n", max_shift_distance);
		} while (max_shift_distance > epsilon);

		// Print every coordinate; the old code indexed [0] and [1] unconditionally,
		// which read out of bounds for 1-D data and dropped dimensions beyond 2.
		for (std::size_t i = 0; i < dataset.size(); ++i) {
			std::cout << "原始坐标 ";
			print_point(std::cout, dataset[i]);
			std::cout << " 滑动到 ";
			print_point(std::cout, shifted_points[i]);
			std::cout << '\n';
		}
		return shifted_points;
	}

	/// Appends `datanums` random points to the dataset. The dimensionality is
	/// span.size(), and coordinate j is rand()/(RAND_MAX+1) * span[j], i.e. in
	/// [0, span[j]) for positive span[j].
	void generatedata(int datanums, const std::vector<int> &span)
	{
		if (datanums > 0) {
			dataset.reserve(dataset.size() + static_cast<std::size_t>(datanums));
		}
		for (int i = 0; i < datanums; ++i) {
			MSData dd(span.size());
			for (std::size_t j = 0; j < span.size(); ++j) {
				dd.data[j] = double(std::rand()) / (RAND_MAX + 1.0) * span[j];
			}
			dataset.push_back(dd);
		}
	}
};
int _tmain(int argc, _TCHAR* argv[])
{
meanshift ms(4);
vector<int>span;
span.push_back(20);
span.push_back(20);
ms.generatedata(100, span);
ms.apply();
return 0;
}
// 运行结果见原文（上方链接）中的截图。