Tree
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v. Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k. Write a program that will count how many pairs which are valid for a given tree. Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros. Output
For each test case output the answer on a single line.
Sample Input 5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0 Sample Output 8 Source |
[Submit] [Go Back] [Status] [Discuss]
题意:给你两个数n,k,再给你一棵树,求树上任意两点距离<=k的有多少对。由于我们队对于树上的算法不太熟悉,我特意来学习,发现还真的很多东西都不懂,这道题我也是查了很多资料才弄懂了,涉及到树的重心的概念,下面说下思路。
这道题用到了树的分治,最直观的思路是我们以某个点为根开始,求出所有点到根节点的距离然后把距离和<=k的累加起来再减去位于同一个子树的点(明显只有不同子树距离才能累加),然后把该点删去,递归到与它相邻的点,重复以上操作得出答案。这个方法可以得出答案,但是当某个根节点只有一条子链时时间复杂度很高,所以每次要以重心为根,根据重心的定义,重心的子树是分配最均匀的,这样时间复杂度就大大缩小。所以根据上面的思路改进一下就好了,每次用dfs找出重心,然后算出所有点到重心的距离,后面就和上面方法一样了,代码如下:
#include<iostream>
#include<cmath>
#include<queue>
#include<cstdio>
#include<queue>
#include<algorithm>
#include<cstring>
#include<string>
#define maxn 10005
#define inf 0x3f3f3f3f
using namespace std;
int head[maxn],vis[maxn],root,size[maxn],son[maxn],minson,dist[maxn],dislen;
int n,k,ans;
struct node{
int u,value,next;
}p[maxn<<1];
void dfssize(int x,int fa){
size[x]=1;
son[x]=0;
for(int i=head[x];~i;i=p[i].next){
int next=p[i].u;
if(next==fa||vis[next])
continue;
dfssize(next,x);
size[x]+=size[next];
if(size[next]>son[x])
son[x]=size[next];
}
}
void dfsroot(int x,int fa,int num){
if(num-size[x]>son[x])
son[x]=num-size[x];
if(son[x]<minson){
minson=son[x];
root=x;
}
for(int i=head[x];~i;i=p[i].next){
int next=p[i].u;
if(next==fa||vis[next])
continue;
dfsroot(next,x,num);
}
}
void dfsdis(int x,int fa,int len){
dist[dislen++]=len;
for(int i=head[x];~i;i=p[i].next){
int next=p[i].u;
if(next==fa||vis[next])
continue;
dfsdis(next,x,len+p[i].value);
}
}
int cal(int st,int len){
dislen=0;
dfsdis(st,0,len);
sort(dist,dist+dislen);
int i=0,j=dislen-1;
int re=0;
while(i<j){
while(dist[i]+dist[j]>k&&i<j){
j--;
}
re+=j-i;
i++;
}
return re;
}
void solve(int x){
minson=inf;
dfssize(x,0);
dfsroot(x,0,size[x]);
ans+=cal(root,0);
vis[root]=1;
for(int i=head[root];~i;i=p[i].next){
int next=p[i].u;
if(vis[next])
continue;
ans-=cal(next,p[i].value);
solve(next);
}
}
int main(){
while(~scanf("%d%d",&n,&k)&&n&&k){
memset(vis,0,sizeof(vis));
memset(head,-1,sizeof(head));
ans=0;
for(int i=0;i<((n-1)<<1);i++){
int u,v,value;
scanf("%d%d%d",&u,&v,&value);
p[i].u=u;
p[i].value=value;
p[i].next=head[v];
head[v]=i++;
p[i].u=v;
p[i].value=value;
p[i].next=head[u];
head[u]=i;
}
solve(1);
printf("%d\n",ans);
}
}//g++ -o a.exe a.cpp