KD树(K-dimension tree),个人理解就是一棵扩展到k维空间上的二叉搜索树,它在每一层用不同的维度进行二分,将空间中的点按一定的顺序组织成二叉树的形式,以方便搜索,具体的内容就不在这里细述了。
来看看这道题吧,题目大意:给定K(<=5)维空间中的n(<=5000)个点,每个坐标的绝对值不超过10000,做t次(<=10000)查询,询问距离指定坐标最近的m个点(m<=10)
直接暴搜的话肯定会超时,这里用KD Tree存储这些点,然后剪枝搜索
这道题有个大坑,数据是多组的,“There are multiple test cases. Process to end of file.”,样例又没有体现,对我这种不认真读题的真是噩梦一般.......
先来思考这么一个问题,如何查找距离某坐标最近的节点?可以根据每一层的划分依据进入左子树或右子树查找,维护一个最小距离,这样当遍历到叶节点的时候我们就能得到当前搜索范围中的最近点,当然这不一定是最近的节点,之所以优先遍历这一条路径是为了引发更多的剪枝。在一层层的向上返回时再去进入另一颗子树查找
那么如何剪枝呢?对于某节点所处的一层来说,这个节点在这个维度上将空间二分了,可以看成一条数轴,如果以目标点为中点、当前最小距离为半径的区间包不包含含该节点,那么显然最短距离只可能落在目标节点一层,而这一侧一定是递归搜索时先进入的子树,再返回到该节点时,发现条件不满足,那么就没有必要进行另一棵子树的搜索了。
这样就完成了找到最近的一个点的任务,那么如何找到最近的M个点?这个就简单了,上面的问题中只要保存找到的最近点就行了,现在保存当前找到的最近M个点的集合,维护最小距离时维护的是当前搜索的节点中第M短的长度,如果集合中元素不足M个,就直接将点加入集合中。具体实现上可以用最大堆来维护,不过因为这道题M<=10,普通的插入排序效率也够了
思想就是这些,下面是代码(还是老样子,写的比较难看....不知为啥一时脑子进水用了索引排序==)
#include <iostream>
#include <cstdio>
#include <cmath>
#define MAXN 50010
#define SQ(X) (X)*(X)
using namespace std;
int data[MAXN][12]; // 点数据
int index[MAXN]; // 点的索引,用于分割
int ans[MAXN]; // 存放查找到的点
struct treeNode
{
int value; // 代表点的索引
int splitIndex; // 分割的维度
treeNode * left;
treeNode * right;
// treeNode(int _val, int _sIndex): value(_val), splitIndex(_sIndex), left(NULL), right(NULL) {}
}trNode[MAXN];
struct KDtree
{
int dimension;
int cnt; // 树中的节点
treeNode * root; // 根节点
treeNode * treeArray[15]; // 保存当前维护的最近m个节点
int dist[15]; // 当前最近m个节点的距离
int top; // 当前维护的集合中的节点个数
KDtree(int k): dimension(k)
{
cnt = 0;
top = 0;
}
int partition(int * dataIndex, int s, int e, int splitIndex) // 将data中的数据根据splitIndex,做一次分割,返回分割归位的位置
{
int tmp = dataIndex[e - 1];
int i = s, j = e - 1;
while (i < j)
{
//printf("%d %d\n", i, j);
while (i < j && data[dataIndex[i]][splitIndex] <= data[tmp][splitIndex])
i++;
if (i == j)
break;
dataIndex[j] = dataIndex[i];
j--;
while (i < j && data[dataIndex[j]][splitIndex] >= data[tmp][splitIndex])
j--;
if (i == j)
break;
dataIndex[i] = dataIndex[j];
i++;
}
dataIndex[i] = tmp;
return i;
}
treeNode * build(int * dataIndex, int s, int e, int d) // 递归建树
{
//printf("%d %d\n", s, e);
if (s >= e)
return NULL;
int splitIndex = d % dimension; // 据说选择划分的维度根据方差效果是最优的,偷懒了> <
int mid = (s + e) >> 1, tmp = e, ptCnt = 0; // 建树的过程中实际上要找到中位数,可以用STL中的nth_element,我这个写的貌似是卡特殊数据了,贡献了无数次TLE....不过对于这道题来说任意的划分一次貌似效果也不错
/* while (tmp != mid && ptCnt < 5)
{
//printf("%d\n", tmp);
if (tmp > mid)
tmp = partition(dataIndex, s, tmp, splitIndex);
else
tmp = partition(dataIndex, tmp + 1, e, splitIndex);
ptCnt++;
}*/
tmp = partition(dataIndex, s, e, splitIndex); // 任意分割一次
trNode[cnt].value = dataIndex[tmp];
trNode[cnt].splitIndex = splitIndex;
trNode[cnt].left = NULL;
trNode[cnt].right = NULL;
treeNode * tmpTreeNode = &trNode[cnt++];
//treeNode * tmpTreeNode = new treeNode(dataIndex[tmp], splitIndex);
tmpTreeNode->left = build(dataIndex, s, tmp, d + 1);
tmpTreeNode->right = build(dataIndex, tmp + 1, e, d + 1);
return tmpTreeNode;
}
void buildKD(int n, int * dataIndex) // 根据索引数组构建KDtree
{
root = build(dataIndex, 0, n, 0);
}
int distance(int * a, int * b) // 返回两个点之间的距离
{
int ans = 0;
for (int i = 0; i < dimension; ++i)
{
ans += (a[i] - b[i]) * (a[i] - b[i]);
}
return ans;
}
void find(int * pt, treeNode * root, int m) // 递归查找节点
{
if (root == NULL)
return;
int splitIndex = root->splitIndex;
treeNode * another; // 未选择的分支
int tmpDist = distance(pt, data[root->value]);
int p = top; // 维护当前距离最近的m个点的数组
while (p > 0 && tmpDist < dist[p - 1]) // 插入当前节点
{
treeArray[p] = treeArray[p - 1];
dist[p] = dist[p - 1];
p--;
}
treeArray[p] = root;
dist[p] = tmpDist;
if (top < m)
++top;
if (pt[splitIndex] < data[root->value][splitIndex]) // 递归搜索
{
another = root->right;
find(pt, root->left, m);
}
else
{
another = root->left;
find(pt, root->right, m);
}
if (top < m || SQ(pt[splitIndex] - data[root->value][splitIndex]) < dist[top - 1]) // 如果另一个分支可能存在最优解继续搜索
find(pt, another, m);
}
void findML(int * pt, int m, int * ansArray)
{
top = 0;
find(pt, root, m);
for (int i = 0; i < m; ++i)
{
ansArray[i] = treeArray[i]->value;
}
return;
}
};
int main()
{
int n, m, k, t;
int pt[16];
while (scanf("%d%d", &n, &k) == 2)
{
KDtree kd(k);
for (int i = 0; i < n; ++i)
{
for (int j = 0; j < k; ++j)
{
scanf("%d", data[i] + j);
}
index[i] = i;
}
kd.buildKD(n, index);
scanf("%d", &t);
while (t--)
{
for (int i = 0; i < k; ++i)
{
scanf("%d", pt + i);
}
scanf("%d", &m);
kd.findML(pt, m, ans);
printf("the closest %d points are:\n", m);
for (int i = 0; i < m; ++i)
{
int tmpIndex = ans[i];
printf("%d", data[tmpIndex][0]);
for (int j = 1; j < k; ++j)
{
printf(" %d", data[tmpIndex][j]);
}
printf("\n");
}
}
}
return 0;
}
205

被折叠的 条评论
为什么被折叠?



