题目描述
题解
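这是一道比较直接的 k-d tree 模板题。
每次询问维护一个容量为 m 的优先队列（按到询问点距离排序的大根堆）：遍历到一个点时，若堆未满则直接入堆，否则与堆顶比较，更近就替换堆顶；进入子树前，用子树包围盒到询问点的最短距离剪枝（堆未满或该距离小于堆顶距离时才进入）。
查询结束后，堆中的元素就是答案；由于是大根堆，依次弹出得到的是由远到近的顺序，输出时需要倒序。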
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
using namespace std;
// Capacity bound on the number of points; tr[] below is sized from it.
const int N = 50005;
// Sentinel that query() uses for a missing child. NOTE(review): assumed to be
// larger than any real squared box distance for the input's coordinate range.
const int inf = 2000000000;
// n: points in the current test case, k: dimensions, T: remaining queries,
// m: neighbours wanted by the current query, root: kd-tree root index,
// cnt: number of points currently held in the heap q.
int n,k,T,m,root,cnt;

// One input point together with its kd-tree bookkeeping.
struct data
{
    int l,r;          // indices of the left/right child in tr[] (0 = no child)
    int d[6];         // coordinates; only the first k are used (so k <= 6)
    int mn[6],mx[6];  // per-dimension bounding box of the subtree rooted here
    int dis;          // ordering key: split coordinate in build(), squared distance to p in query()
    // Order by `dis`: this makes priority_queue<data> a max-heap on distance
    // and lets nth_element partition on the current split coordinate.
    bool operator < (const data &a) const
    {
        return dis<a.dis;
    }
};

// `struct data` (an elaborated type specifier) is used on purpose: with
// `using namespace std;` in C++17, a bare `data` here is ambiguous with the
// std::data function template and the file fails to compile.
// tr[1..n] holds the tree built in place, p is the current query point,
// ans[1..m] buffers one query's answer, and q keeps the m best candidates.
struct data tr[N*4],p,ans[15];
priority_queue <struct data> q;
// Grows node x's bounding box (mn/mx, which build() sets to x's own point
// before calling this) so that it also covers the boxes of x's children.
void update(int x)
{
    const int children[2]={tr[x].l,tr[x].r};
    for (int c=0;c<2;++c)
    {
        const int child=children[c];
        if (!child) continue;  // index 0 means "no child"
        for (int i=0;i<k;++i)
        {
            if (tr[child].mx[i]>tr[x].mx[i]) tr[x].mx[i]=tr[child].mx[i];
            if (tr[child].mn[i]<tr[x].mn[i]) tr[x].mn[i]=tr[child].mn[i];
        }
    }
}
// Recursively builds a kd-tree over tr[l..r] in place and returns the index of
// its root, or 0 for an empty range (0 is the "no child" index that update()
// and query() test for). d is the split dimension for this level; it cycles
// through 0..k-1.
int build(int l,int r,int d)
{
    // Empty range: calling nth_element with mid outside [l, r] would be UB.
    if (l>r) return 0;
    if (d==k) d=0;
    int mid=(l+r)>>1;
    // Median split on dimension d; operator< compares the `dis` key.
    for (int i=l;i<=r;++i) tr[i].dis=tr[i].d[d];
    nth_element(tr+l,tr+mid,tr+r+1);
    for (int i=0;i<k;++i) tr[mid].mx[i]=tr[mid].mn[i]=tr[mid].d[i];
    // Assign both links unconditionally (an empty side yields 0) so a node
    // never keeps a stale child link when the caller did not zero tr[] first.
    tr[mid].l=build(l,mid-1,d+1);
    tr[mid].r=build(mid+1,r,d+1);
    update(mid);
    return mid;
}
// Square of an integer, used to accumulate squared Euclidean distances.
int qr(int x)
{
    const int squared=x*x;
    return squared;
}
// Pruning lower bound: squared Euclidean distance from the query point p to
// the bounding box [mn, mx] of node `now` (0 when p lies inside the box).
int dist(int now)
{
    int total=0;
    for (int dim=0;dim<k;++dim)
    {
        const int beyondMax=p.d[dim]-tr[now].mx[dim];  // > 0 when p is above the box
        const int beforeMin=tr[now].mn[dim]-p.d[dim];  // > 0 when p is below the box
        if (beyondMax>0) total+=qr(beyondMax);
        if (beforeMin>0) total+=qr(beforeMin);
    }
    return total;
}
void query(int now)
{
int dl,dr,d0=0;
for (int i=0;i<k;++i) d0+=qr(tr[now].d[i]-p.d[i]);
tr[now].dis=d0;
if (cnt<m)
{
++cnt;
q.push(tr[now]);
}
else
{
if (tr[now].dis<q.top().dis)
{
q.pop();
q.push(tr[now]);
}
}
if (tr[now].l) dl=dist(tr[now].l);
else dl=inf;
if (tr[now].r) dr=dist(tr[now].r);
else dr=inf;
if (dl<dr)
{
if (dl!=inf&&(cnt<m||dl<q.top().dis)) query(tr[now].l);
if (dr!=inf&&(cnt<m||dr<q.top().dis)) query(tr[now].r);
}
else
{
if (dr!=inf&&(cnt<m||dr<q.top().dis)) query(tr[now].r);
if (dl!=inf&&(cnt<m||dl<q.top().dis)) query(tr[now].l);
}
}
// Driver. Each test case is "n k", then n points of k coordinates, then T
// queries of the form "k coordinates, m". For every query, the m points
// nearest to it are printed from nearest to farthest.
int main()
{
    // Stop on EOF *or* malformed input: `~scanf(...)` only stopped on EOF and
    // spun forever when scanf hit a matching failure and returned 0.
    while (scanf("%d%d",&n,&k)==2)
    {
        // Zero every node so no child links or boxes survive from the previous case.
        memset(tr,0,sizeof(tr));
        for (int i=1;i<=n;++i)
            for (int j=0;j<k;++j) scanf("%d",&tr[i].d[j]);
        // build(1,0,...) would call nth_element with an out-of-range pivot, so
        // only build when there is at least one point; root 0 = empty tree.
        root=n>0 ? build(1,n,0) : 0;
        if (scanf("%d",&T)!=1) break;
        while (T--)
        {
            for (int i=0;i<k;++i) scanf("%d",&p.d[i]);
            scanf("%d",&m);cnt=0;
            while (!q.empty()) q.pop();
            if (root) query(root);
            printf("the closest %d points are:\n",m);
            // q is a max-heap on distance, so it pops farthest-first; fill ans
            // from the back to print nearest-first. Drain only what the search
            // collected (fewer than m when m > n) instead of popping an empty heap.
            // NOTE(review): ans has room for indices 1..14, so this assumes m <= 14.
            const int got=(int)q.size();
            for (int i=got;i>=1;--i) ans[i]=q.top(),q.pop();
            for (int i=1;i<=got;++i)
                for (int j=0;j<k;++j) printf("%d%c",ans[i].d[j]," \n"[j==k-1]);
        }
    }
}