【题目大意】
给定一棵N个节点、边上带权的无向无根树,再给出一个K,询问有多少个数对(i,j)满足i<j,且i与j两点在树上的距离小于等于K。
【数据规模】
多组测试数据,每组数据满足N≤10000,1≤边上权值≤1000,1≤K≤10^9。
【出处】
楼天城男人必做8题之一……
【分析】
很显然有个暴力算法:从每个点出发遍历整棵树,统计数对个数。
也很显然时间复杂度O(N^2),要TLE。
1、经过根节点
2、不经过根节点,也就是说在根节点的一棵子树中
对于情况2,可以递归求解,所以只要考虑情况1就行了。
我们要求的是满足i<j且dist[i]+dist[j]<=K且经过根节点的数对(i,j)的个数,但是直接不好求,于是我们:
设Y为满足i<j且dist[i]+dist[j]<=K且Lca(i,j)=i或Lca(i,j)=j的数对(i,j)的个数
那么我们要统计的量便等于X-Y
求X、Y的过程均可以转化为以下问题:
已知A[1],A[2],...A[m],求满足i<j且A[i]+A[j]<=K的数对(i,j)的个数
对于这个问题,我们先将A从小到大排序。
设B[i]表示满足A[i]+A[p]<=K的最大的p(若不存在则为0)。我们的任务便转化为求出A所对应的B数组。那么,若B[i]>i,那么i对答案的贡献为B[i]-i。
显然,随着i的增大,B[i]的值是不会增大的。利用这个性质,我们可以在线性的时间内求出B数组,从而得到答案。
综上,设递归最大层数为L,因为每一层的时间复杂度均为“瓶颈”——排序的时间复杂度O(NlogN),所以总的时间复杂度为O(L*NlogN)
然而,如果遇到极端情况——这棵树是一根链,那么随意分割势必会导致层数达到O(N)级别,对于N=10000的数据是无法承受的。
于是我们将递归改为点的分治。
我们在每一棵子树中选择“最优”的点分割。所谓“最优”,是指删除这个点后最大的子树尽量小,也就是树的重心。
重心可以在线性时间复杂度内解决,不会增加时间复杂度。这样一来,即使是遇到一根链的情况时,L的值也仅仅是O(logN)的。
因此,改进后算法时间复杂度为O(Nlog^2N),可以AC。
【代码】
//Ciocio's Code
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define rrep(i,b,a) for(int i=(b);i>=(a);--i)
#define MAXN 10010
#define INF (~0U>>1)
#define sf scanf
#define pf printf
struct node{
int v,w;
int sum,maxv;
node* next;
}*head[MAXN],tree[MAXN<<1],tp[MAXN];
int N,K,tot,totEdge,ans;
bool vis[MAXN];
int hash[MAXN];
int size[MAXN];
int dist[MAXN];
void Clear(){
ans=totEdge=0;
memset(head,0,sizeof head);
memset(vis,0,sizeof vis);
}
void Addedge(int u,int v,int w){
tree[totEdge].v=v;
tree[totEdge].w=w;
tree[totEdge].next=head[u];
head[u]=&tree[totEdge++];
}
void Init(){
int u,v,w;
rep(i,1,N-1){
sf("%d%d%d",&u,&v,&w);
Addedge(u,v,w);
Addedge(v,u,w);
}
}
void Dfs(int s,int fa){
tp[s].sum=1;
tp[s].maxv=0;
for(node* p=head[s];p!=NULL;p=p->next)
if(p->v!=fa&&!vis[p->v]){
Dfs(p->v,s);
tp[s].sum+=tp[p->v].sum;
tp[s].maxv=max(tp[s].maxv,tp[p->v].sum);
}
hash[tot]=s;
size[tot++]=tp[s].maxv;
}
int Getroot(int s){
tot=0;
Dfs(s,0);
int tmp=INF,maxr,cnt=tp[s].sum;
rep(i,0,tot-1){
size[i]=max(size[i],cnt-size[i]-1);
if(size[i]<tmp){
tmp=size[i];
maxr=hash[i];
}
}
return maxr;
}
void Getdis(int s,int fa,int dis){
dist[tot++]=dis;
for(node* p=head[s];p!=NULL;p=p->next)
if(p->v!=fa&&!vis[p->v]&&dis+p->w<=K)
Getdis(p->v,s,dis+p->w);
}
void Plus(int s){
sort(dist,dist+tot);
int left=0,right=tot-1;
while(left<right){
if(dist[left]+dist[right]<=K)
ans+=right-left,left++;
else right--;
}
}
void Sub(int s){
for(node* p=head[s];p!=NULL;p=p->next)
if(!vis[p->v]){
tot=0;
Getdis(p->v,s,p->w);
sort(dist,dist+tot);
int left=0,right=tot-1;
while(left<right){
if(dist[left]+dist[right]<=K)
ans-=right-left,left++;
else right--;
}
}
}
void Solve(int s,int fa){
int root=Getroot(s);
vis[root]=true;
tot=0;
Getdis(root,0,0);
Plus(root);
Sub(root);
for(node* p=head[root];p!=NULL;p=p->next)
if(p->v!=fa&&!vis[p->v])
Solve(p->v,root);
}
int main(){
while((sf("%d%d",&N,&K)!=EOF)&&N+K){
Clear();
Init();
Solve(1,0);
pf("%d\n",ans);
}
return 0;
}