Description
Given a tree with n vertices, each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many pairs are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k (n<=10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
Source
LouTiancheng@POJ
Solution
跨年赛的一道题,树分治入门题(?)
推荐看集训队论文 漆子超:分治算法在树的路径问题中的应用
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#define Max(a,b) (a>b?a:b)
using namespace std;
// ---- Global state shared by every routine in this file ----
// Vertex count and distance bound of the current test case.
int n;
int k;
// Best split vertex found so far by get_root; 0 is a sentinel slot whose
// maxv entry is set very large by main so that any real vertex beats it.
int root;
// Number of entries of edges[] currently in use.
int cnt;
// Running count of valid pairs for the current test case.
int ans;
// Number of distances written into deep[] by the latest traversal.
int num;
// Component node count that get_root uses for the "everything outside
// this subtree" term.
int nodenum;
// head[u]: index in edges[] of the first adjacency entry of u, or -1.
int head[10005];
// deep[1..num]: distances collected by get_deep (1-based).
int deep[10005];
// maxv[u]: largest piece left if the component were split at u.
int maxv[10005];
// size[u]: subtree node count of u as computed by the latest get_root.
int size[10005];
// visited[u]: true once u has been used as a split vertex.
bool visited[10005];
// One directed entry of the adjacency list. main inserts every undirected
// tree edge twice, once per direction.
struct Node
{
    int next; // index in edges[] of the next entry of the same vertex, or -1
    int to;   // vertex this entry points at
    int w;    // edge length
};

// Pool backing all adjacency lists; add() hands out slots in order.
Node edges[20010];
void add(int u,int v,int w)
{
edges[cnt].next=head[u];
edges[cnt].to=v;
edges[cnt].w=w;
head[u]=cnt;cnt++;
}
void get_root(int u,int f)
{
size[u]=1;
maxv[u]=0;
for(int i=head[u];~i;i=edges[i].next)
{
if(edges[i].to==f||visited[edges[i].to])continue;
get_root(edges[i].to,u);
size[u]+=size[edges[i].to];
maxv[u]=Max(maxv[u],size[edges[i].to]);
}
maxv[u]=Max(maxv[u],nodenum-size[u]);
if(maxv[u]<maxv[root])root=u;
}
// Appends to deep[] the distance `now` of u and, recursively, the distance
// of every vertex reachable from u without passing through f or through a
// vertex already marked in visited[]. Entries are written 1-based and
// `num` is advanced once per vertex reached.
void get_deep(int u,int f,int now)
{
    num+=1;
    deep[num]=now;
    int e=head[u];
    while(e!=-1)
    {
        const int child=edges[e].to;
        if(child!=f&&!visited[child])
            get_deep(child,u,now+edges[e].w);
        e=edges[e].next;
    }
}
// Collects the distances of all unvisited vertices reachable from u
// (u itself starting at distance `now`) and returns how many unordered
// pairs of those distances sum to at most k.
//
// Side effects: overwrites deep[1..num] and leaves `num` equal to the
// number of vertices that were reached.
int calc(int u,int now)
{
    num=0;
    get_deep(u,0,now);
    std::sort(deep+1,deep+num+1);

    int pairs=0;
    int lo=1;
    int hi=num;
    // Two pointers over the sorted distances: when the smallest and the
    // largest remaining values fit, the smallest fits with everything in
    // between as well.
    while(lo<hi)
    {
        if(deep[lo]+deep[hi]>k)
        {
            --hi;
            continue;
        }
        pairs+=hi-lo;
        ++lo;
    }
    return pairs;
}
void work(int u)
{
visited[u]=1;
ans+=calc(u,0);
for(int i=head[u];~i;i=edges[i].next)
{
if(visited[edges[i].to])continue;
ans-=calc(edges[i].to,edges[i].w);
root=0;
nodenum=size[edges[i].to];
get_root(edges[i].to,u);
work(root);
}
}
int main()
{
while(~scanf("%d%d",&n,&k)&&n!=0)
{
cnt=0;ans=0;root=0;maxv[0]=0x3f3f3f3f;
memset(size,0,sizeof(size));
memset(visited,0,sizeof(visited));
memset(head,-1,sizeof(head));
int u,v,w;
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
get_root(1,0);
work(root);
printf("%d\n",ans);
}
return 0;
}