题解
就是裸的点分治。
对于每个重心,先dfs把所有点到重心的路径的长度和代价处理好。
然后按代价排个序,首尾指针(l,r)一起扫一遍,r从右到左,把每个len[l]+len[r]≤wlen[l]+len[r]≤w的点的的路径长度都加进个BIT里,直接求即可。每一对都算了两次,最后除以个二就好了。
空间得开够啊。
代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+100;
typedef long long ll;
int n,m,L,d[N],vis[N];ll W;
int S,MX,root,l,r;
int head[N],to[N<<1],nxt[N<<1],tot;
int tre[N],sz[N],mx[N],cnt;
ll ans,dis[N],cc[N<<1];
struct P{ll ss;int len;}tp[N];
inline bool cmp(const P&x,const P&y){return x.ss<y.ss;}
inline void lk(int u,int v,int val)
{to[++tot]=v;nxt[tot]=head[u];head[u]=tot;cc[tot]=val;}
inline void ad(ll x)
{for(;x<=n+1;x+=(x&(-x))) tre[x]++;}
inline ll get(int x)
{ll re=0;for(;x;x-=(x&(-x))) re+=tre[x];return re;}
inline void clr(ll x)
{for(;x<=n+1;x+=(x&(-x))) tre[x]=0;}
inline void getrt(int x,int fa)
{
sz[x]=1;mx[x]=0;int i,j;
for(i=head[x];i;i=nxt[i]){
j=to[i];if(vis[j] || j==fa) continue;
getrt(j,x);sz[x]+=sz[j];
mx[x]=max(mx[x],sz[j]);
}
mx[x]=max(mx[x],S-sz[x]);
if(mx[x]<MX) {MX=mx[x];root=x;}
}
inline void dfs(int x,int fa)
{
cnt++;tp[cnt].ss=dis[x];tp[cnt].len=d[x];
for(int j,i=head[x];i;i=nxt[i]){
j=to[i];if(j==fa || vis[j]) continue;
dis[j]=dis[x]+cc[i];d[j]=d[x]+1;
dfs(j,x);
}
}
inline void dfss(int x,int fa)
{
cnt++;tp[cnt].ss=dis[x];tp[cnt].len=d[x];
for(int j,i=head[x];i;i=nxt[i]){
j=to[i];if(vis[j] || j==fa) continue;
dfss(j,x);
}
}
inline void calc(int op)
{
if(!cnt) return;
sort(tp+1,tp+cnt+1,cmp);
l=1;int i;ll j=0;
for(r=cnt;r>=1;--r){
for(;tp[l].ss+tp[r].ss<=W && l<=cnt;l++) ad(tp[l].len+1);
if(tp[r].len>L) continue;
j+=get(L+1-tp[r].len);if(l>r && tp[r].len<=L-tp[r].len) j--;
}
if(op==1) ans+=j/2;else ans-=j/2;
for(i=1;i<=cnt;++i) clr(tp[i].len+1);
}
inline void divide(int x)
{
int i,j;
vis[x]=1;dis[x]=0,d[x]=0;
cnt=0;dfs(x,0);
calc(1);
for(i=head[x];i;i=nxt[i]){
j=to[i];if(vis[j]) continue;
cnt=0;dfss(j,x);calc(0);sz[j]=cnt;
}
for(i=head[x];i;i=nxt[i]){
j=to[i];if(vis[j]) continue;
S=sz[j];MX=n+1;getrt(j,x);divide(root);
}
}
int main(){
int i,j,ix,iy,iz;
scanf("%d%d%I64d",&n,&L,&W);
for(i=2;i<=n;++i){
scanf("%d%d",&ix,&iy);
lk(i,ix,iy);lk(ix,i,iy);
}
S=n;MX=n+1;
getrt(1,0);
divide(root);
printf("%I64d\n",ans);
}