题目:https://nanti.jisuanke.com/t/38229
题意:树上有边权,每次查询树上两点路径上边权小于等于k的数量。
思路:把边权下放到深度较大的端点(即子节点)作为点权;用树链剖分把路径拆成若干段 DFS 序连续区间;再用主席树(可持久化线段树)统计各区间内小于等于 k 的数的个数。LCA 自身的点权对应它上方的边,不计入答案。
代码:
#include <bits/stdc++.h>
#define LL long long
using namespace std;
// (neighbour, edge weight) pair used in the adjacency lists.
using P = pair<int, int>;
const int maxn = 1e5 + 5;

// e[u]: every (neighbour, weight) incident to u; each edge is stored twice.
vector<P> e[maxn];
// Pool of all edge and query weights; main sorts and de-duplicates it.
vector<int> vec;

// Node pool of the persistent segment tree: l / r are child indices, sum is
// the number of values in the node's range. Node 0 is never written, so it
// stays all-zero and serves as the empty tree.
struct T { int l, r, sum; } T[maxn << 5];

int n, m;            // vertex count, query count
int root[maxn];      // root[i]: tree version holding the values at DFS positions 1..i
int tot;             // number of nodes allocated in T so far
int sz;              // number of distinct compressed values

// Heavy-light decomposition state.
int cnt;             // running DFS timestamp used by dfs2
int rk[maxn];        // rk[p]: value stored at DFS position p
int depth[maxn];     // depth of each vertex (root has depth 1)
int father[maxn];    // parent vertex (0 for the root)
int id[maxn];        // DFS position of each vertex
int siz[maxn];       // subtree size
int heavy_son[maxn]; // child with the largest subtree (0 for a leaf)
int top[maxn];       // head of the heavy chain containing the vertex
int a[maxn];         // a[v]: weight of the edge from v up to its parent
// Inserts one occurrence of compressed value `pos` into a new version of the
// persistent segment tree covering [l, r].
// Each node on the root-to-leaf path is cloned from the matching node of
// version `old` (index 0 is the empty tree) and its count is incremented.
// Subtrees off that path are shared with `old`. On return, `now` holds the
// index of the freshly allocated node for [l, r].
void update(int &now, int old, int l, int r, int pos){
    ++tot;
    T[tot] = T[old];
    T[tot].sum += 1;
    now = tot;
    if(l != r){
        const int mid = (l + r) >> 1;
        if(pos > mid) update(T[now].r, T[old].r, mid + 1, r, pos);
        else update(T[now].l, T[old].l, l, mid, pos);
    }
}
// Returns how many values with compressed rank in [L, R] were inserted after
// version x, up to and including version y.
// Both trees are walked in lockstep over the node range [l, r]; the answer is
// the difference of their counts.
int query(int x, int y, int l, int r, int L, int R){
    // Node range entirely inside [L, R]: the count difference is the answer.
    if(L <= l && r <= R) return T[y].sum - T[x].sum;
    const int mid = (l + r) >> 1;
    const int fromLeft = (L <= mid) ? query(T[x].l, T[y].l, l, mid, L, R) : 0;
    const int fromRight = (mid < R) ? query(T[x].r, T[y].r, mid + 1, r, L, R) : 0;
    return fromLeft + fromRight;
}
// First heavy-light pass over the subtree rooted at u, whose parent is fa
// (0 for the tree root). It fills:
//   depth[x]     = depth[parent] + 1
//   father[x]    = parent vertex
//   siz[x]       = subtree size
//   heavy_son[x] = child with the largest subtree (the first such child in
//                  adjacency order wins ties; 0 for a leaf)
//   a[v]         = weight of the edge (father[v], v), i.e. each edge weight is
//                  pushed down onto its lower endpoint.
//
// This uses an explicit stack instead of recursion. A path-shaped tree with
// ~1e5 vertices would otherwise nest ~1e5 call frames and can overflow the
// call stack. The visiting order and tie-breaking match the recursive form.
void dfs1(int u, int fa){
    // Frame = (vertex, index of the next adjacency entry to examine).
    vector<pair<int, size_t>> stk;
    depth[u] = depth[fa] + 1;
    heavy_son[u] = 0; siz[u] = 1; father[u] = fa;
    stk.push_back({u, 0});
    while(!stk.empty()){
        const int x = stk.back().first;
        const size_t i = stk.back().second;
        if(i < e[x].size()){
            // Advance this frame's cursor before a push can reallocate stk.
            stk.back().second = i + 1;
            const int v = e[x][i].first, w = e[x][i].second;
            if(v == father[x]) continue;
            a[v] = w;
            depth[v] = depth[x] + 1;
            heavy_son[v] = 0; siz[v] = 1; father[v] = x;
            stk.push_back({v, 0});
        }else{
            // x's subtree is finished. Fold it into its parent, as the
            // recursive version did right after the child call returned.
            stk.pop_back();
            if(!stk.empty()){
                const int p = stk.back().first;
                siz[p] += siz[x];
                if(siz[heavy_son[p]] < siz[x]) heavy_son[p] = x;
            }
        }
    }
}
// Second heavy-light pass: assigns DFS positions so that every heavy chain
// occupies a contiguous run of positions. For each vertex x in u's subtree:
//   id[x]     = its DFS position (continuing from the global counter cnt)
//   rk[id[x]] = a[x], the pushed-down edge weight
//   top[x]    = head of x's heavy chain (rt for the chain through u)
// Vertices are visited in preorder: the heavy child first, then the light
// children in adjacency order, each light child starting a new chain.
//
// This uses an explicit stack instead of recursion, so a deep (path-like)
// tree cannot overflow the call stack. The visiting order is identical to
// the recursive form.
void dfs2(int u, int rt){
    // Pending (vertex, chain head) pairs; the back entry is visited next.
    vector<pair<int, int>> stk(1, make_pair(u, rt));
    while(!stk.empty()){
        const int x = stk.back().first, head = stk.back().second;
        stk.pop_back();
        id[x] = ++cnt; rk[cnt] = a[x]; top[x] = head;
        // Push light children in reverse so they pop in adjacency order.
        // Push the heavy child last so it pops (and is numbered) first.
        for(auto it = e[x].rbegin(); it != e[x].rend(); ++it){
            const int v = it->first;
            if(v == father[x] || v == heavy_son[x]) continue;
            stk.push_back(make_pair(v, v));
        }
        if(heavy_son[x]) stk.push_back(make_pair(heavy_son[x], head));
    }
}
// Counts edges on the tree path u-v whose compressed weight rank lies in
// [1, x].
// Each edge weight lives at the DFS position of its lower endpoint.
// While u and v sit on different heavy chains, the endpoint whose chain head
// is deeper contributes its whole chain segment [id[head], id[endpoint]]; the
// head's value is the edge to father[head], which the path crosses. That
// endpoint then jumps above the chain.
// Once both endpoints share a chain, the shallower one is the LCA. Its own
// position is excluded (versions id[lo]..id[hi] cover positions lo+1..hi)
// because its value belongs to the edge above it.
int Query(int u, int v, int x){
    int total = 0;
    while(top[u] != top[v]){
        if(depth[top[u]] < depth[top[v]]) swap(u, v);
        const int head = top[u];
        total += query(root[id[head] - 1], root[id[u]], 1, sz, 1, x);
        u = father[head];
    }
    if(u != v){
        const int lo = min(id[u], id[v]);
        const int hi = max(id[u], id[v]);
        total += query(root[lo], root[hi], 1, sz, 1, x);
    }
    return total;
}
// One offline query as read by main: path endpoints and the weight threshold.
struct Q {
    int u;  // first endpoint
    int v;  // second endpoint
    int w;  // threshold k: count edges with weight <= k
} q[maxn];
// Maps x to a 1-based rank in the sorted, de-duplicated pool vec.
// If x is present, this is its exact rank; otherwise it is the rank of the
// first element >= x (or vec.size() + 1 if there is none).
int getid(int x){
    const auto firstNotLess = lower_bound(vec.begin(), vec.end(), x);
    return int(firstNotLess - vec.begin()) + 1;
}
// Reads an n-vertex weighted tree and m queries "u v k", and prints, one per
// line, the number of edges on path u-v with weight <= k.
// Steps: coordinate-compress every edge and query weight into ranks 1..sz;
// run the heavy-light decomposition; build one persistent-segment-tree
// version per DFS position (root[i] holds positions 1..i); answer the
// queries in input order.
int main()
{
    int u, v, w;
    scanf("%d%d", &n, &m);
    for(int i=1; i<n; i++){
        scanf("%d%d%d", &u, &v, &w);
        e[u].push_back(P(v, w)); e[v].push_back(P(u, w));
        vec.push_back(w);
    }
    // Queries are read up front so that their thresholds join the compression
    // pool. getid(k) is then k's exact rank, and the rank range [1, getid(k)]
    // means "weight <= k".
    for(int i=1; i<=m; i++){
        scanf("%d%d%d", &u, &v, &w);
        q[i] = Q{u, v, w};
        vec.push_back(w);
    }
    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());
    sz = vec.size();
    // With n == 1 and m == 0 the pool is empty (sz == 0). Building over the
    // empty range [1, 0] would make update() recurse forever, since
    // mid + 1 == l and r never shrinks. There is also nothing to answer.
    if(sz == 0) return 0;
    dfs1(1, 0); dfs2(1, 1);
    for(int i=1; i<=n; i++) update(root[i], root[i-1], 1, sz, getid(rk[i]));
    for(int i=1; i<=m; i++) printf("%d\n", Query(q[i].u, q[i].v, getid(q[i].w)));
    return 0;
}