传送门
分析
上午刚写了一个 $dfs$ 序处理主席树的题,下午多校就来了一个差不多的。
可以推出结论:如果根节点符合要求,那么子树中所有在范围内的节点必然是连续的(与根连通),所以问题转换成了:求 $x$ 节点子树内所有符合条件的节点个数。
首先需要向上求出来,最远的符合条件的节点,然后以这个节点为根节点,这个可以用倍增维护,剩下的就离散化一下,主席树上搞一搞就出来了
代码
#pragma GCC optimize(3)
#include <bits/stdc++.h>
// Competitive-programming shorthands.
// The output helpers expand to plain expressions (no trailing ';') so that
// `if (c) di(x); else di(y);` parses as intended, and their argument is
// parenthesized so operators inside it bind correctly.
#define debug(x) (cout << #x << ":" << (x) << endl)
#define dl(x) printf("%lld\n", (x))
#define di(x) printf("%d\n", (x))
// NOTE(review): MSVC only honours _CRT_SECURE_NO_WARNINGS when it is defined
// before the first CRT header; here it follows <bits/stdc++.h>, so it is
// presumably inert. Move it above the #include if the warning matters.
#define _CRT_SECURE_NO_WARNINGS
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
using namespace std;

// Short aliases for frequently used types.
using ll = long long;
using ull = unsigned long long;
using PII = pair<int, int>;
using VI = vector<int>;

// Problem-wide constants.
constexpr int INF = 0x3f3f3f3f;            // large sentinel value
constexpr int N = 100000 + 10, M = 2 * N;  // max nodes (with slack) / max directed edges
constexpr ll mod = 1000000007;
constexpr double eps = 1e-9;
const double PI = acos(-1);
/// Reads one (optionally negative) integer from stdin into `a`.
/// Skips any non-digit characters first; a '-' seen while skipping makes the
/// result negative. Consumes the single character following the digits.
/// At end of input it stops and stores 0 instead of spinning forever.
template<typename T>inline void read(T &a) {
    // `int`, not `char`: getchar() returns EOF out-of-band, and isdigit()
    // requires an unsigned-char value or EOF (a negative char is UB).
    int c = getchar();
    T x = 0, f = 1;
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    a = f * x;
}
/// Greatest common divisor of a and b, always non-negative (gcd(0, 0) == 0).
/// Iterative Euclid; loops on `b != 0` so negative arguments are handled too.
/// Note: |INT_MIN| is not representable, so gcd(INT_MIN, 0) overflows.
int gcd(int a, int b) {
    while (b != 0) {
        const int rem = a % b;
        a = b;
        b = rem;
    }
    return a < 0 ? -a : a;
}
// ---- Global state shared by the routines below ----
int n, m;            // node count and query count, as read in main
vector<int> nums;    // sorted, de-duplicated weights (plus INF): value-compression table

// One node of the persistent segment tree built over indices of `nums`.
struct Node {
    int l, r;        // child node indices in `tr` (0 = the all-zero node)
    int cnt;         // how many inserted values fall in this node's index range
};
// Node pool; each insert() appends one node per tree level.
Node tr[N * 66];

int root[N], idx;            // root[k]: tree version after the first k DFS-order nodes; idx: last used pool slot
int h[N], ne[M], e[M], num;  // adjacency lists: head per node, next edge, edge target, edge count
int in[N], out[N], cnt;      // DFS entry / exit timestamps and the running timestamp
int fa[N][25], mx[N][25];    // binary lifting: 2^k-th ancestor and max ancestor weight over that jump
int w[N];                    // node weights
// Appends the directed edge x -> y to the front of x's adjacency list.
void add(int x, int y) {
    e[num] = y;
    ne[num] = h[x];
    h[x] = num++;
}
// Index of the first entry of the sorted table `nums` that is >= x
// (for a weight present in `nums`, its compressed index).
int find(int x) {
    auto pos = lower_bound(nums.begin(), nums.end(), x);
    return static_cast<int>(pos - nums.begin());
}
// Allocates a complete segment-tree skeleton over index range [l, r] from the
// pool `tr` and returns its root index. Counts are not written here (they stay
// as found in the freshly allocated global slots).
int build(int l, int r) {
    const int node = ++idx;
    if (l != r) {
        const int mid = (l + r) >> 1;
        tr[node].l = build(l, mid);
        tr[node].r = build(mid + 1, r);
    }
    return node;
}
// Persistent point update: returns the root of a new version of the tree rooted
// at p (covering [l, r]) with one more occurrence at index x. Only the nodes on
// the root-to-leaf path are copied; every other subtree is shared with p.
int insert(int p, int l, int r, int x) {
    const int q = ++idx;
    tr[q] = tr[p];               // begin as a copy of the old node
    if (l == r) {
        ++tr[q].cnt;
        return q;
    }
    const int mid = (l + r) >> 1;
    if (x > mid) {
        tr[q].r = insert(tr[p].r, mid + 1, r, x);
    } else {
        tr[q].l = insert(tr[p].l, l, mid, x);
    }
    tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
    return q;
}
// Depth-first traversal from u (whose parent is f). It:
//  - assigns the entry timestamp in[u] and exit timestamp out[u], so u's
//    subtree occupies timestamps in[u]..out[u];
//  - creates tree version root[in[u]] = previous version + w[u];
//  - fills the binary-lifting tables fa[u][*] and mx[u][*].
// NOTE(review): recursion depth equals the tree height (up to n), so a
// path-shaped tree may need a large stack.
void dfs(int u, int f) {
    in[u] = ++cnt;
    root[cnt] = insert(root[cnt - 1], 0, nums.size() - 1, find(w[u]));
    fa[u][0] = f;
    mx[u][0] = (u == 1) ? INF : w[f];   // node 1 is the root: no parent, INF sentinel
    for (int k = 1; k <= 20; ++k) {
        const int half = fa[u][k - 1];
        fa[u][k] = fa[half][k - 1];
        mx[u][k] = max(mx[u][k - 1], mx[half][k - 1]);
    }
    for (int edge = h[u]; edge != -1; edge = ne[edge]) {
        const int child = e[edge];
        if (child != f) dfs(child, u);
    }
    out[u] = cnt;
}
// Greedy binary lifting: moves x to its farthest ancestor such that every
// ancestor passed on the way has weight <= k (mx[x][s] is the maximum over
// the 2^s ancestors above x). Returns x unchanged if its parent exceeds k.
int ans_up(int x, int k) {
    for (int step = 20; step >= 0; --step) {
        const int target = fa[x][step];
        if (target != 0 && mx[x][step] <= k) x = target;
    }
    return x;
}
// Counts the inserted values v with L <= v <= R in the tree version whose node u
// covers the sorted compressed values nums[l..r].
// Returns 0 when [L, R] is empty or contains none of nums[l..r]. Previously such
// calls recursed forever on a leaf (and read nums[mid + 1] past the end); it
// only worked because main pre-filtered queries.
int query(int u, int l, int r, int L, int R) {
    // Disjoint (or empty query interval): nothing to count.
    if (L > R || nums[r] < L || nums[l] > R) return 0;
    // Node range lies entirely inside [L, R].
    if (nums[l] >= L && nums[r] <= R) return tr[u].cnt;
    // Partial overlap implies l < r: a leaf is always either inside or disjoint.
    const int mid = (l + r) >> 1;
    return query(tr[u].l, l, mid, L, R) + query(tr[u].r, mid + 1, r, L, R);
}
int main() {
    memset(h, -1, sizeof h);   // -1 terminates every adjacency list
    read(n);

    // n - 1 undirected edges, each stored as two directed ones.
    for (int edges = n - 1; edges > 0; --edges) {
        int a, b;
        read(a);
        read(b);
        add(a, b);
        add(b, a);
    }

    // Node weights, also collected for value compression.
    for (int i = 1; i <= n; ++i) {
        read(w[i]);
        nums.push_back(w[i]);
    }
    nums.push_back(INF);
    sort(nums.begin(), nums.end());
    nums.erase(unique(nums.begin(), nums.end()), nums.end());

    const int last = nums.size() - 1;   // nums is fixed from here on
    root[0] = build(0, last);
    dfs(1, 0);

    read(m);
    while (m--) {
        int x, l, r;
        read(x);
        read(l);
        read(r);
        if (l <= w[x] && w[x] <= r) {
            // Climb to the highest ancestor still within the upper bound, then count
            // in-range weights over its DFS interval [in[top], out[top]].
            const int top = ans_up(x, r);
            const int total = query(root[out[top]], 0, last, l, r)
                            - query(root[in[top] - 1], 0, last, l, r);
            printf("%d\n", total);
        } else {
            puts("0");
        }
    }
    return 0;
}