/*
题意
树上路径与子树的区间修改、查询
题解
树链剖分+线段树
调试记录
线段树update没有pushdown
没有遍历重儿子
size没有更新
误以为dep大的dfn小（实际上同一条重链上dep越大，dfn越大）
*/
#include <cstdio>
#include <algorithm>
#define maxn 100005
using namespace std;
// One directed edge of the array-based adjacency list. Each undirected tree
// edge is inserted twice (u->v and v->u), hence 2 * maxn slots.
struct node{
    int to;   // endpoint of the edge
    int next; // index of the previously added edge from the same vertex; 0 ends the list
}e[maxn << 1];
int head[maxn];      // index of the most recently added edge out of each vertex (0 = none)
int n, m, root, mo;  // vertex count, operation count, root vertex, modulus
int val[maxn];       // initial weight of each vertex, indexed by vertex id
int val2[maxn];      // the same weights, indexed by dfn
int f[maxn];         // parent vertex (0 for the root)
int dep[maxn];       // depth, root has depth 1
int dfn[maxn];       // position of the vertex in the heavy-first dfs order
int size[maxn];      // number of vertices in the subtree
int son[maxn];       // heavy child (0 for a leaf)
int top[maxn];       // first vertex of the heavy chain containing this vertex
// Segment-tree node covering dfn positions [l, r]: key is the range sum
// modulo mo, lazy is an addition not yet pushed down to the children.
struct node2{
    int key;
    int l;
    int r;
    int lazy;
}a[maxn << 2];
int tot{0};          // number of edges currently stored in e[]
// Insert the directed edge u -> v at the front of u's adjacency list.
void addedge(int u, int v){
    ++tot;
    e[tot].to = v;
    e[tot].next = head[u];
    head[u] = tot;
}
int cnt{0};  // last dfn handed out by dfs2
void dfs(int cur, int fa, int deep){
f[cur] = fa;
dep[cur] = deep;
size[cur] = 1;
int Max = -1;
for (int i = head[cur]; i; i = e[i].next){
if (e[i].to != fa){
dfs(e[i].to, cur, deep + 1);
size[cur] += size[e[i].to];
if (size[e[i].to] > Max) Max = size[e[i].to], son[cur] = e[i].to;
}
}
}
// Second pass of heavy-light decomposition: assigns dfn in a dfs that always
// descends into the heavy child first, so every heavy chain (and every subtree)
// occupies a contiguous dfn range. Also records the chain head top[] and copies
// the vertex weight into val2[] at its dfn position.
void dfs2(int cur, int topf){
    top[cur] = topf;
    cnt++;
    dfn[cur] = cnt;
    val2[cnt] = val[cur];
    const int heavy = son[cur];
    if (heavy == 0) return;   // leaf: nothing below
    dfs2(heavy, topf);        // heavy child continues the current chain
    for (int i = head[cur]; i != 0; i = e[i].next){
        const int v = e[i].to;
        if (v == f[cur] || v == heavy) continue;
        dfs2(v, v);           // each light child starts a new chain
    }
}
// Push the pending addition stored in node cur down to its two children and
// clear it. In this file it is only reached for partially overlapped nodes,
// which are never leaves, so the children always exist.
//
// mo is read from input as an int and may be close to INT_MAX, so the
// products lazy * length and the sums key + delta are formed in long long;
// the original int arithmetic could overflow (undefined behavior).
void pushdown(int cur){
    const long long add = a[cur].lazy;
    if (add == 0) return;
    const int lc = cur << 1;
    const int rc = cur << 1 | 1;
    const int mid = (a[cur].l + a[cur].r) >> 1;
    const long long lenL = mid - a[cur].l + 1;
    const long long lenR = a[cur].r - mid;
    a[lc].key = static_cast<int>((a[lc].key + add * lenL) % mo);
    a[rc].key = static_cast<int>((a[rc].key + add * lenR) % mo);
    a[lc].lazy = static_cast<int>((a[lc].lazy + add) % mo);
    a[rc].lazy = static_cast<int>((a[rc].lazy + add) % mo);
    a[cur].lazy = 0;
}
// Add k to every dfn position in [l, r] that lies inside segment-tree node cur.
// k is normalized into [0, mo) here, so callers may pass an unreduced value.
// All products and sums are computed in long long: with mo near INT_MAX,
// length * k, lazy + k and the sum of two child residues overflow int.
void update(int cur, int l, int r, int k){
    if (a[cur].l > r || a[cur].r < l) return;          // disjoint
    if (l <= a[cur].l && a[cur].r <= r){               // fully covered: tag lazily
        const long long add = (static_cast<long long>(k) % mo + mo) % mo;
        const long long len = a[cur].r - a[cur].l + 1;
        a[cur].key = static_cast<int>((a[cur].key + add * len) % mo);
        a[cur].lazy = static_cast<int>((a[cur].lazy + add) % mo);
        return;
    }
    pushdown(cur);                                     // partial overlap: recurse
    update(cur << 1, l, r, k);
    update(cur << 1 | 1, l, r, k);
    a[cur].key = static_cast<int>(
        (static_cast<long long>(a[cur << 1].key) + a[cur << 1 | 1].key) % mo);
}
// Build segment-tree node cur over dfn positions [l, r] from val2[] (weights
// in dfn order). Each node stores its range sum modulo mo and an empty lazy tag.
void build(int cur, int l, int r){
    a[cur].l = l;
    a[cur].r = r;
    a[cur].lazy = 0;  // make a rebuild start from a clean state
    if (l == r){
        a[cur].key = val2[l] % mo;
        return;
    }
    const int mid = (l + r) >> 1;
    build(cur << 1, l, mid);
    build(cur << 1 | 1, mid + 1, r);
    // Two residues below mo can sum past INT_MAX when mo > 2^30: add in long long.
    a[cur].key = static_cast<int>(
        (static_cast<long long>(a[cur << 1].key) + a[cur << 1 | 1].key) % mo);
}
// Sum modulo mo of the dfn positions in [l, r] that lie inside node cur.
// The two child results are each below mo, but their sum can exceed INT_MAX
// when mo > 2^30, so they are added in long long before reducing.
int Query(int cur, int l, int r){
    if (a[cur].l > r || a[cur].r < l) return 0;                // disjoint
    if (l <= a[cur].l && a[cur].r <= r) return a[cur].key;     // fully covered
    pushdown(cur);
    const long long lhs = Query(cur << 1, l, r);
    const long long rhs = Query(cur << 1 | 1, l, r);
    return static_cast<int>((lhs + rhs) % mo);
}
// Sum modulo mo of the vertex weights on the tree path x -> y.
// While x and y are on different heavy chains, the endpoint whose chain head
// is deeper has its chain segment [dfn[top], dfn[v]] summed, then jumps above
// that chain. Once both endpoints share a chain, the shallower one has the
// smaller dfn and the remaining segment is summed.
int qRange(int x, int y){
    long long ans = 0;  // long long: ans + a residue below mo can exceed INT_MAX
    while (top[x] != top[y]){
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ans = (ans + Query(1, dfn[top[x]], dfn[x])) % mo;
        x = f[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    ans = (ans + Query(1, dfn[x], dfn[y])) % mo;
    return static_cast<int>(ans);
}
// Add k (reduced modulo mo) to every vertex on the tree path x -> y, using the
// same chain-climbing walk as the path query.
void uRange(int x, int y, int k){
    k %= mo;
    for (; top[x] != top[y]; x = f[top[x]]){
        // climb from the endpoint whose chain head is deeper
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update(1, dfn[top[x]], dfn[x], k);
    }
    // same chain now: the shallower endpoint has the smaller dfn
    if (dep[x] > dep[y]) swap(x, y);
    update(1, dfn[x], dfn[y], k);
}
int qSon(int x){
return Query(1, dfn[x], dfn[x] + size[x] - 1);
}
void uSon(int x, int k){
update(1, dfn[x], dfn[x] + size[x] - 1, k);
}
// Read the next non-negative decimal integer from stdin. All non-digit
// characters before it are skipped, including a '-' sign, so negative input
// is not supported. Returns 0 if end of input is reached before any digit.
//
// getchar() is kept in an int so that EOF can be told apart from a real
// character. The previous `char` version never tested for EOF and spun
// forever once input ran out.
inline int read(){
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    int x = 0;
    while (ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
// Input: n m root mo, then n vertex weights, then n-1 undirected edges, then m
// operations:
//   1 x y k : add k to every vertex on path x-y
//   2 x y   : print path x-y sum mod mo
//   3 x k   : add k to every vertex in subtree of x
//   4 x     : print subtree-of-x sum mod mo
int main(){
    n = read();
    m = read();
    root = read();
    mo = read();
    for (int i = 1; i <= n; ++i) val[i] = read();
    for (int i = 1; i < n; ++i){
        const int u = read();
        const int v = read();
        addedge(u, v);
        addedge(v, u);
    }
    dfs(root, 0, 1);    // parents, depths, subtree sizes, heavy children
    dfs2(root, root);   // heavy-first dfn order and chain heads
    build(1, 1, n);     // segment tree over weights in dfn order
    for (int q = 0; q < m; ++q){
        switch (read()){
        case 1: {
            const int x = read();
            const int y = read();
            const int k = read();
            uRange(x, y, k);
            break;
        }
        case 2: {
            const int x = read();
            const int y = read();
            printf("%d\n", qRange(x, y));
            break;
        }
        case 3: {
            const int x = read();
            const int k = read();
            uSon(x, k);
            break;
        }
        case 4: {
            const int x = read();
            printf("%d\n", qSon(x));
            break;
        }
        }
    }
    return 0;
}