题目
给出一棵带边权的树,问有多少对点的距离<=Len
第一行两个整数N,Len(2<=n<=10000,len<=maxlongint)
接下来N-1行,每行3个整数,x,y,l,表示x和y有一条边长为l的边
分析
我们可以先考虑先确定了根的树的答案,明显的是将每个点的距离算出来。
然后再排序,用线性的设两个指针来算答案。
记录i,j,然后明显的若a[j]是第一个使得a[i]+a[j]<=m的,后面的也一定小于m了。
而这时a[i+1]+a[j+1]一定是>m,因为a[i]+a[j+1]>m,所以能和i匹配的一定是<=j的。
int i=1,j=a[0].a;
while (i<j){
if (a[j].b+a[i].b>m) j--; else {
an+=j-i;i++;
}
}
然而我们可以发现这样是O(n2)
故我们继续思考如何优化。
现在我们引出一个新的算法——点分治。
其实也是确定根,然后做上面的操作。
只是我们这样做会做到很多重复的点,而且很慢。
所以我们可以每次把这棵树分成几个部分,
然后分别递归处理每棵子树的答案,就可以了。
这里大概分2个步骤:
1.计算。其实就是先用O(num),(num表示当前子树的节点个数)的时间
去把每个节点的距离算出来,然后排序,求出
dis[i]+dis[j]<=len的答案-dis[i]+dis[j]<=len且在当前的这棵树是同一棵子树的答案
排序后用线性的方法去求答案。
至于LIHUI问为什么这样是对的呢?
我来解答:
因为我们每次处理的是以规定的一个点为根,经不经过是指这个点,而不是其他的点。
这样我们便可以保证每次处理出每棵子树的答案。
2.分开。然后就可以接着找下个分割点把这棵树再分开。
注意分割点最好是这颗子树的重心,因为这样可以平均的把树分开,这样就不会被一条链卡住了
其实点分治就是利用了二分这种最基础的思想,将一颗树分成多块去考虑,这样还是用原来的处理方法,但是大幅度减少了时间。
PS:大家看代码时,可以把手打快拍改成sort,只不过我很逗比而已。。。
#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<cstdio>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
const int M=20010;
struct note{
int a,b;
}a[M],d[M];
int n,m,x,y,z,i,las[2*M],nex[2*M],b[2*M],v[2*M],nu,size[M],fa[M],an;
bool bz[M];
void insert(int x,int y,int z)
{
b[++nu]=y;nex[nu]=las[x];las[x]=nu;v[nu]=z;
}
void pre(int x,int y,int sum,int &ze){
int p=las[x];
size[x]=1;fa[x]=y;
bool be=true;
while (p) {
if ((b[p]!=y)&&(!bz[b[p]])) {
pre(b[p],x,sum,ze);
size[x]+=size[b[p]];
if (size[b[p]]>(sum>>1)) be=false;
}
p=nex[p];
}
if (sum-size[x]>(sum>>1)) be=false;
if (be) ze=x;
}
void add(int x,int y,int dis,int z)
{
int p=las[x];
a[++a[0].a].a=z;
a[a[0].a].b=dis;
while (p){
if ((b[p]!=y)&&(!bz[b[p]])) add(b[p],x,dis+v[p],z);
p=nex[p];
}
}
void qsort(int l,int r){
int i=l,j=r,mid=a[(l+r)>>1].b;
do{
while (a[i].b<mid) i++;
while (a[j].b>mid) j--;
if (i<=j) {
a[M].a=a[i].a;a[i].a=a[j].a;a[j].a=a[M].a;
a[M].b=a[i].b;a[i].b=a[j].b;a[j].b=a[M].b;
i++;j--;
}
}
while (i<=j);
if (l<j) qsort(l,j);
if (i<r) qsort(i,r);
}
void work(int x,int y,int sum)
{
int z=0;
pre(x,y,sum,z);
a[0].a=1;int k=0,p=las[z];
a[1].a=0;a[1].b=0;
while (p){
k++;
if ((!bz[b[p]])) add(b[p],z,v[p],k);
p=nex[p];
}
qsort(1,a[0].a);
int i=1,j=a[0].a;
while (i<j){
if (a[j].b+a[i].b>m) j--; else {
an+=j-i;i++;
}
}
fo(q,0,k){
d[0].a=0;
fo(w,1,a[0].a)
if (a[w].a==q) {
d[++d[0].a].a=a[w].a;
d[d[0].a].b=a[w].b;}
i=1;j=d[0].a;
while (i<j){
if (d[j].b+d[i].b>m) j--;else {
an-=j-i;i++;
}
}
}
bz[z]=true;
p=las[z];
while (p){
if (!bz[b[p]]) {
if (b[p]==fa[z]) work(b[p],z,sum-size[z]); else work(b[p],z,size[b[p]]);}
p=nex[p];
}
}
int main(){
scanf("%d%d",&n,&m);
fo(i,1,n-1){
scanf("%d%d%d",&x,&y,&z);
insert(x,y,z);insert(y,x,z);
}
work(1,0,n);
printf("%d",an);
}