参考(还有图,讲的明白!)
树链剖分:即把树分成多条链来进行所需要的操作,比较常见的的是重链剖分。先介绍部分概念
重儿子:一个节点的所有儿子中子树最大的那个儿子
轻儿子:除了重儿子之外的儿子
重边:连接重儿子的边称为重边
轻边:连接轻儿子的边
重链:由重边组成的树链;
如果单个结点也可视为重链,就可以把整棵树剖分成多条重链。
我们可以通过一次dfs找出每个结点的重儿子,然后再进行一次dfs把树剖分成多条重链,我们确定一个结点属于哪个重链只需要看它所处链的顶部结点是哪一个,设top[i],表示结点i所处重链的顶部结点,如果top[i] = top[j],则i,j处于同一条链上,然后结合dfs序的知识,在第二次dfs的同时,按照先遍历重儿子的原则进行标号,那么对于最后得到的dfs序,有以下的性质
1.同一条重链中所有节点的dfs序是一个连续的区间
2.一个结点的及其子树也是一个连续的区间
所以我们就可以用线段树,树状数组来维护最大最小值,求和等
但是如果求解的结点是在同一条链上还好说,直接就是一个连续区间,不在同一条链上的话,我们就需要借助于类似倍增求解LCA的方式来进行向上跳跃,直到二者处于同一条重链上为止;
#include<bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<map>
#include<cstring>
#include<queue>
#define lson (rt << 1)
#define rson (rt << 1 | 1)
const int maxn = 1e5+10;
const int maxm = 5e4 + 10;
const int inf_max = 0x3f3f3f;
using namespace std;
typedef long long ll;
struct EDG {
int v,nxt;
EDG() {}
EDG(int tv,int tn) {v = tv,nxt = tn;}
}edge[maxn<<1];
int hson[maxn],head[maxn],cnt,tot,redfn[maxn],dfnl[maxn],val[maxn],dfnr[maxn],sz[maxn],fa[maxn],top[maxn],dep[maxn],sum[maxn<<2],lazy[maxn<<2],n,m,p,r;
//hson[i]:结点i的重儿子
//redfn[i]:dfs序为i的结点标号是redfn[i];
//dfnl[i]:i号结点的dfs序
//top[i]:i所处链的顶端
//dep[i]:i的深度
//fa[i]:i的父亲
//sz[i]:i的子树大小,用于确定重儿子
void add_edge(int u,int v) {
edge[cnt] = EDG(v,head[u]);
head[u] = cnt++;
}
//重链paofen
int dfs1(int now,int pre,int deep) {
dep[now] = deep;sz[now] = 1;
for(int i = head[now]; ~i;i = edge[i].nxt) {
int v = edge[i].v;
if(v == pre) continue;
fa[v] = now;
sz[now] += dfs1(v,now,deep + 1);
if(hson[now] == -1 || sz[hson[now]] < sz[v]) hson[now] = v;
}
return sz[now];
}
void dfs2(int now,int pre,int tp) {
dfnl[now] = ++tot;top[now] = tp;redfn[tot] = now;
if(hson[now] == -1) {dfnr[now] = tot;return ;}
dfs2(hson[now],now,tp);
for(int i = head[now];~i;i = edge[i].nxt) {
int v = edge[i].v;
if(v == hson[now] || v == pre) continue;
dfs2(v,now,v);
}
dfnr[now] = tot;
}
//线段树
void pushup(int rt) {
sum[rt] = (sum[rson] + sum[lson]) % p;
}
void pushdown(int rt,int l,int r) {
if(lazy[rt]) {
int mid = (l + r) >> 1;
lazy[lson] += lazy[rt],lazy[rson] += lazy[rt];
sum[lson] = (sum[lson] + (mid + 1 - l)*lazy[rt]%p)%p,sum[rson] = (sum[rson] + (r - mid) * lazy[rt]%p)%p;
lazy[rt] = 0;
}
}
void build(int rt,int l,int r) {
lazy[rt] = 0;
if(l == r) {
sum[rt] = val[redfn[r]] % p;
return ;
}
int mid = (l+r) >> 1;
build(lson,l,mid);
build(rson,mid + 1, r);
pushup(rt);
}
void add(int rt,int l,int r,int ql,int qr,int v) {
if(l == ql && r == qr) {
sum[rt] = (sum[rt] + (r + 1 - l) * v % p) % p;
lazy[rt] += v;
return ;
}
pushdown(rt,l,r);
int mid = (l + r) >> 1;
if(qr <= mid) add(lson,l,mid,ql,qr,v);
else if(ql > mid) add(rson,mid + 1,r,ql,qr,v);
else {
add(lson,l,mid,ql,mid,v),add(rson,mid + 1,r,mid + 1,qr,v);
}
pushup(rt);
}
int query(int rt,int l,int r,int ql,int qr) {
if(l == ql && r == qr) return sum[rt];
pushdown(rt,l,r);
int mid = (l + r) >> 1;
if(qr <= mid) return query(lson,l,mid,ql,qr);
else if(ql > mid) return query(rson,mid + 1,r,ql,qr);
else return (query(lson,l,mid,ql,mid)+ query(rson,mid + 1,r,mid + 1,qr)) % p;
}
int main()
{
memset(hson,-1,sizeof(hson));
memset(head,-1,sizeof(head));
scanf("%d%d%d%d",&n,&m,&r,&p);
cnt = tot = 0;fa[r] = r;
for(int i = 1;i <= n; ++i) scanf("%d",&val[i]);
for(int i = 1;i < n; ++i) {
int u,v;
scanf("%d%d",&u,&v);
add_edge(u,v),add_edge(v,u);
}
sz[r] = dfs1(r,-1,1);
dfs2(r,-1,r);
build(1,1,n);
for(int i = 1;i <= m; ++i) {
int op,x,y,z;
scanf("%d",&op);
if(op == 1) {
scanf("%d%d%d",&x,&y,&z);
int fx = top[x],fy = top[y];
while(fx != fy) {
if(dep[fx] > dep[fy]) {add(1,1,n,dfnl[fx],dfnl[x],z),x = fa[fx];}
else {add(1,1,n,dfnl[fy],dfnl[y],z),y = fa[fy];}
fx = top[x],fy = top[y];
}
if(dep[x] > dep[y]) add(1,1,n,dfnl[y],dfnl[x],z);
else add(1,1,n,dfnl[x],dfnl[y],z);
}else if(op == 2) {
int ans = 0;
scanf("%d%d",&x,&y);
int fx = top[x],fy = top[y];
while(fx != fy) {
if(dep[fx] > dep[fy]) { ans = (ans + query(1,1,n,dfnl[fx],dfnl[x]))%p,x = fa[fx]; }
else { ans = (ans + query(1,1,n,dfnl[fy],dfnl[y]))%p, y = fa[fy];}
fx = top[x],fy = top[y];
}
if(dep[x] > dep[y]) ans = (ans + query(1,1,n,dfnl[y],dfnl[x]))%p;
else ans = (ans + query(1,1,n,dfnl[x],dfnl[y]))%p;
printf("%d\n",ans);
}else if(op == 3) {
scanf("%d%d",&x,&z);
add(1,1,n,dfnl[x],dfnr[x],z);
}else if(op == 4) {
scanf("%d",&x);
printf("%d\n",query(1,1,n,dfnl[x],dfnr[x]));
}
}
return 0;
}