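// Reference: https://www.cnblogs.com/90zeng/p/kdtree.html
// The original author's article is excellent; I only reformatted the code to my own habits and made a few small changes. Thanks to the author.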
#include <iostream>
#include <vector>
#include "kd_tree.hpp"
using namespace std;
int main()
{
int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
vector<vector<double> > train(6, vector<double>(2, 0));
for (unsigned i = 0; i < 6; ++i)
for (unsigned j = 0; j < 2; ++j)
train[i][j] = data[i][j];
auto* kdTree = new kd_tree;
build_kd_tree(kdTree, train, 0);
print_kd_tree(kdTree, 0);
vector<double> goal;
goal.push_back(3);
goal.push_back(4.5);
vector<double> nearest_neighbor = search_nearest_neighbor(goal, kdTree);
auto beg = nearest_neighbor.begin();
cout << "The nearest neighbor is: ";
while(beg != nearest_neighbor.end()) cout << *beg++ << ",";
cout << endl;
return 0;
}
//
// Created by gu on 12/16/2020.
//
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;
/**
 * One node of a kd-tree. Every node is also the root of the sub-tree below it.
 *
 * Ownership: `left_child` and `right_child` are owning pointers and are
 * released with `delete` in the destructor, so they must either be null or
 * point to nodes allocated with `new` (which is what build_kd_tree does).
 * `parent` is a non-owning back pointer. Copying is disabled because a
 * member-wise copy would make two nodes delete the same children.
 */
class kd_tree{
public:
    std::vector<double> root;   // point stored at this node; empty means "no point"
    kd_tree* parent;            // non-owning; nullptr for the top-most node
    kd_tree* left_child;        // owning; may be nullptr
    kd_tree* right_child;       // owning; may be nullptr

    // Default constructor: an empty node with no relatives.
    kd_tree() : parent(nullptr), left_child(nullptr), right_child(nullptr){}

    // Releases the whole sub-tree below this node (deleting nullptr is a no-op).
    ~kd_tree()
    {
        delete left_child;
        delete right_child;
    }

    kd_tree(const kd_tree&) = delete;
    kd_tree& operator=(const kd_tree&) = delete;

    // True when this node stores no point.
    bool is_empty() const
    {
        return root.empty();
    }

    // True when this node stores a point and has no children at all.
    bool is_leaf() const
    {
        return (!root.empty()) && right_child == nullptr && left_child == nullptr;
    }

    // True when this node stores a point and has no parent.
    bool is_root() const
    {
        return (!is_empty()) && parent == nullptr;
    }

    // True when this node is the left child of its parent. Compares node
    // identity (not the stored point), and is false for a node with no parent.
    bool is_left() const
    {
        return parent != nullptr && parent->left_child == this;
    }

    // True when this node is the right child of its parent. Compares node
    // identity (not the stored point), and is false for a node with no parent.
    bool is_right() const
    {
        return parent != nullptr && parent->right_child == this;
    }
};
/*
 * Returns the transpose of a matrix stored as a vector of rows.
 *
 * Matrix: input rows. The column count is taken from the first row, and every
 *         row is assumed to have at least that many elements.
 * Returns a matrix with one row per input column; an empty input yields an
 * empty result.
 * */
template<typename T>
vector<vector<T> > transpose(const vector<vector<T> >& Matrix)
{
    // Guard first: Matrix[0] must not be touched when there are no rows.
    if (Matrix.empty())
    {
        return vector<vector<T> >();
    }
    const std::size_t row = Matrix.size();
    const std::size_t col = Matrix[0].size();
    vector<vector<T> > Trans(col, vector<T>(row));
    for (std::size_t i = 0; i < col; ++i)
    {
        for (std::size_t j = 0; j < row; ++j)
        {
            Trans[i][j] = Matrix[j][i];
        }
    }
    return Trans;
}
/*
 * Returns the element that would sit at index size()/2 if `vec` were sorted
 * in ascending order (for an even number of elements this is the upper of
 * the two middle values).
 *
 * vec: the values, taken by value because they are reordered locally.
 * Throws std::out_of_range when `vec` is empty.
 * */
template <typename T>
T find_middle_value(vector<T> vec)
{
    const typename vector<T>::size_type pos = vec.size() / 2;
    // Only the element at `pos` is needed, so a selection is enough; a full
    // sort is not required. For an empty vector the nth iterator equals
    // end() and the call does nothing.
    std::nth_element(vec.begin(),
                     vec.begin() + static_cast<typename vector<T>::difference_type>(pos),
                     vec.end());
    // at() turns the empty-input case into an exception instead of an
    // out-of-bounds read.
    return vec.at(pos);
}
//构建kd树
void build_kd_tree(kd_tree* tree, vector<vector<double> > data, unsigned depth)
{
//样本的数量
unsigned samples_num = data.size();
//终止条件
if (samples_num == 0)
{
return;
}
if (samples_num == 1)
{
tree->root = data[0];
return;
}
//样本的维度
unsigned dimension = data[0].size();
vector<vector<double> > trans_data = transpose(data);
//选择切分属性
unsigned split_attribute = depth % dimension;
vector<double> split_attribute_values = trans_data[split_attribute];
//选择切分值
double splitValue = find_middle_value(split_attribute_values);
//cout << "splitValue" << splitValue << endl;
// 根据选定的切分属性和切分值,将数据集分为两个子集
vector<vector<double>> subset1;
vector<vector<double>> subset2;
for (unsigned i = 0; i < samples_num; i++)
{
if (split_attribute_values[i] == splitValue && tree->root.empty())
tree->root = data[i];
else
{
if (split_attribute_values[i] < splitValue)
subset1.push_back(data[i]);
else
subset2.push_back(data[i]);
}
}
//子集递归调用buildkd_tree函数
tree->left_child = new kd_tree;
tree->left_child->parent = tree;
tree->right_child = new kd_tree;
tree->right_child->parent = tree;
build_kd_tree(tree->left_child, subset1, depth + 1);
build_kd_tree(tree->right_child, subset2, depth + 1);
}
/*
 * Prints the kd-tree to std::cout, one node per line, indented by depth.
 *
 * tree:  node to print (must not be null).
 * depth: number of tab characters written before this node's coordinates.
 *
 * Every coordinate is followed by a comma. A node with at least one child
 * pointer set then prints each existing child behind a " left:" / "right:"
 * label at depth + 1, with a line break after each child slot.
 */
void print_kd_tree(kd_tree *tree, unsigned depth)
{
    // Writes `count` tab characters.
    const auto emit_tabs = [](unsigned count) {
        while (count-- > 0)
            std::cout << "\t";
    };

    // This node's own line.
    emit_tabs(depth);
    for (auto it = tree->root.begin(); it != tree->root.end(); ++it)
        std::cout << *it << ",";
    std::cout << std::endl;

    kd_tree* const left = tree->left_child;
    kd_tree* const right = tree->right_child;

    // Nothing more to print when neither child pointer is set.
    if (left == nullptr && right == nullptr)
        return;

    const unsigned child_depth = depth + 1;

    if (left != nullptr)
    {
        emit_tabs(child_depth);
        std::cout << " left:";
        print_kd_tree(left, child_depth);
    }
    std::cout << std::endl;

    if (right != nullptr)
    {
        emit_tabs(child_depth);
        std::cout << "right:";
        print_kd_tree(right, child_depth);
    }
    std::cout << std::endl;
}
/*
 * Computes the distance between two points.
 *
 * point1, point2: the points; they must have the same number of coordinates.
 * method: 0 = Euclidean distance, 1 = Manhattan distance.
 *
 * Returns the distance, or -1 after writing a message to std::cerr when the
 * dimensions differ or `method` is not 0 or 1.
 */
double measure_distance(const vector<double>& point1, const vector<double>& point2, unsigned method)
{
    if (point1.size() != point2.size())
    {
        cerr << "Dimensions don't match!!" << endl;
        return -1;
    }
    switch (method)
    {
        case 0: // Euclidean distance
        {
            double res = 0;
            for (vector<double>::size_type i = 0; i < point1.size(); i++)
            {
                const double diff = point1[i] - point2[i];
                res += diff * diff;
            }
            return std::sqrt(res);
        }
        case 1: // Manhattan distance
        {
            double res = 0;
            for (vector<double>::size_type i = 0; i < point1.size(); i++)
            {
                res += std::fabs(point1[i] - point2[i]);
            }
            return res;
        }
        default:
        {
            cerr << "Invalid method!!" << endl;
            return -1;
        }
    }
}
//在kd tree中搜索目标点的最近邻
//输入:目标点, 已构造的kd树
//输出:目标点的最近邻
vector<double> search_nearest_neighbor(vector<double> goal, kd_tree *tree)
{
/*第一步:在kd树中找出包含目标点的叶子结点:从根结点出发,
递归的向下访问kd树,若目标点的当前维的坐标小于切分点的
坐标,则移动到左子结点,否则移动到右子结点,直到子结点为
叶结点为止,以此叶子结点为“当前最近点”
*/
unsigned k = tree->root.size(); //计算出数据的维数
unsigned d = 0; //维度初始化为0,即从第1维开始
kd_tree* current_tree = tree;
vector<double> current_nearest = current_tree->root;
while(!current_tree->is_leaf())
{
unsigned index = d % k; //计算当前维度,深度
if (current_tree->right_child->is_empty() || goal[index] < current_nearest[index])
{
current_tree = current_tree->left_child;
}
else
{
current_tree = current_tree->right_child;
}
d++;
}
current_nearest = current_tree->root;
/*第二步:递归地向上回退, 在每个结点进行如下操作:
(a)如果该结点保存的实例比当前最近点距离目标点更近,则以该例点为“当前最近点”
(b)当前最近点一定存在于某结点一个子结点对应的区域,检查该子结点的父结点的另
一子结点对应区域是否有更近的点(即检查另一子结点对应的区域是否与以目标点为球
心、以目标点与“当前最近点”间的距离为半径的球体相交);如果相交,可能在另一
个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着递归进行最
近邻搜索;如果不相交,向上回退*/
//当前最近邻与目标点的距离
double current_distance = measure_distance(goal, current_nearest, 0);
//如果当前子kd树的根结点是其父结点的左孩子,则搜索其父结点的右孩子结点所代表的区域,反之亦反
kd_tree* search_district;
if (current_tree->is_left())
{
if (current_tree->parent->right_child == nullptr)
search_district = current_tree;
else
search_district = current_tree->parent->right_child;
}
else
{
if (current_tree->parent->left_child == nullptr)
search_district = current_tree;
else
search_district = current_tree->parent->left_child;
}
//如果搜索区域对应的子kd树的根结点不是整个kd树的根结点,继续回退搜索
while (search_district->parent != nullptr)
{
//搜索区域与目标点的最近距离
double district_distance = abs(goal[(d + 1) % k] - search_district->parent->root[(d + 1) % k]);
//如果“搜索区域与目标点的最近距离”比“当前最近邻与目标点的距离”短,表明搜索区域内可能存在距离目标点更近的点
if (district_distance < current_distance )//&& !search_district->is_empty()
{
double parent_distance = measure_distance(goal, search_district->parent->root, 0);
if (parent_distance < current_distance)
{
current_distance = parent_distance;
current_tree = search_district->parent;
current_nearest = current_tree->root;
}
if (!search_district->is_empty())
{
double root_distance = measure_distance(goal, search_district->root, 0);
if (root_distance < current_distance)
{
current_distance = root_distance;
current_tree = search_district;
current_nearest = current_tree->root;
}
}
if (search_district->left_child != nullptr)
{
double left_distance = measure_distance(goal, search_district->left_child->root, 0);
if (left_distance < current_distance)
{
current_distance = left_distance;
current_tree = search_district;
current_nearest = current_tree->root;
}
}
if (search_district->right_child != nullptr)
{
double right_distance = measure_distance(goal, search_district->right_child->root, 0);
if (right_distance < current_distance)
{
current_distance = right_distance;
current_tree = search_district;
current_nearest = current_tree->root;
}
}
}//end if
if (search_district->parent->parent != nullptr)
{
search_district = search_district->parent->is_left()?
search_district->parent->parent->right_child:
search_district->parent->parent->left_child;
}
else
{
search_district = search_district->parent;
}
d++;
}//end while
return current_nearest;
}

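// This article explains in detail how to use the kd-tree algorithm to partition two-dimensional data and how to search for the nearest neighbour of a target point. A worked example shows the build, print and search steps; it is suited to understanding how spatial data structures and search algorithms are applied.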
923






