思路:把树剖成重链(边剖把边存在较深的一个点里去,然后查询时做特殊处理),然后建成主席树,对于每一次询问,二分每一条重链上的第k小,判断是否小于等于z。
代码:
#include<bits/stdc++.h>
using namespace std;
#define inf 0x3f3f3f3f
const int maxn=1e5+9;
int tot=0,n,son[maxn],root[maxn],d[maxn],f[maxn],size[maxn],id[maxn],di[maxn],rk[maxn],top[maxn],head[maxn],val[maxn];
struct Pt{
int val,id;
}pt[maxn],cp[maxn];
struct node{
int to,val,next;
}edge[maxn*2];
int cnt=1;
void add(int u,int v,int val){
edge[cnt].next=head[u];
edge[cnt].to=v;
edge[cnt].val=val;
head[u]=cnt++;
}
bool cmp(Pt a,Pt b){
return a.val<b.val;
}
struct tree{
int sum,ls,rs;
}tr[maxn*20];
void dfs1(int u,int pre){
f[u]=pre;
size[u]=1;
d[u]=d[pre]+1;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].to;
if(v==pre)continue;
dfs1(v,u);
val[v]=edge[i].val;
size[u]+=size[v];
if(size[v]>=size[son[u]]){
son[u]=v;
}
}
}
void dfs2(int u,int tp){
id[u]=++tot;
di[tot]=u;
top[u]=tp;
pt[tot].val=val[u];
pt[tot].id=tot;
if(son[u])dfs2(son[u],tp);
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].to;
if(v!=f[u]&&v!=son[u])dfs2(v,v);
}
}
int ft=1;
void Insert(int num,int &rt,int l,int r){
tr[ft++] = tr[rt];
rt = ft - 1;
tr[rt].sum++;
if(l==r)return;
int mid=(l+r)>>1;
if(num<=mid)Insert(num,tr[rt].ls,l,mid);
else Insert(num,tr[rt].rs,mid+1,r);
}
int query(int i,int j,int l,int r,int k){
int d=tr[tr[j].ls].sum-tr[tr[i].ls].sum;
if(l==r)return l;
int mid=(l+r)>>1;
if(k<=d)return query(tr[i].ls,tr[j].ls,l,mid,k);
else return query(tr[i].rs,tr[j].rs,mid+1,r,k-d);
}
void solve(int x,int y,int z){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]])swap(x,y);
int l=1,r=id[x]-id[top[x]]+2;
while(l<r){
int mid=(l+r)>>1;
int pos=query(root[id[top[x]]-1],root[id[x]],1,n,mid);
if(pt[pos].val<=z){
l=mid+1;
}
else r=mid;
}
ans+=l-1;
x=f[top[x]];
}
if(id[x]>id[y])swap(x,y);
if(x==y){//最近公共祖先不考虑
printf("%d\n",ans);
return ;
}
int l=1,r=id[y]-id[son[x]]+2;//若在同一条重链,链的起始点不考虑
while(l<r){
int mid=(l+r)>>1;
int pos=query(root[id[son[x]]-1],root[id[y]],1,n,mid);
if(pt[pos].val<=z){
l=mid+1;
}
else r=mid;
}
ans+=l-1;
printf("%d\n",ans);
return ;
}
int main(){
int i,j,k,m;
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
int x,y,z;
for(i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);add(y,x,z);
}
dfs1(1,1);
dfs2(1,1);
sort(pt+1,pt+n+1,cmp);
for(i=1;i<=n;i++){
rk[pt[i].id]=i;
}
for(i=1;i<=n;i++){
root[i]=root[i-1];
Insert(rk[i],root[i],1,n);
}
int u,v;
for(i=0;i<m;i++){
scanf("%d%d%d",&u,&v,&z);
solve(u,v,z);
}
}