题目链接:https://www.luogu.org/problemnew/show/P3384
例题:
【ZJOI2008】树的统计:https://www.luogu.org/problemnew/show/P2590
【HAOI2015】树上操作:https://www.luogu.org/problemnew/show/P3178
【USACO11DEC】Grass Planting:https://www.luogu.org/problemnew/show/P3038
树链剖分是解决树上问题的利器,它利用了线段树和 dfs d f s 序的性质,使其能在 log l o g 级复杂度内解决部分树上修改、查询操作
我们针对这样一棵树来学习树链剖分:
我们首先来思考比较容易的子树修改与查询问题
显然地,如果我们从根开始 dfs d f s ,得到这些节点编号的序列,即 dfs d f s 序,节点 x x 位置记为,同时统计包括该点在内的子树大小记为 siz[x] s i z [ x ] ,那么这个节点及其子树在 dfs序 d f s 序 所在区间为 [id[x],id[x]+siz[x]−1] [ i d [ x ] , i d [ x ] + s i z [ x ] − 1 ]
于是我们可以得到下面这段代码
dfs(x,f)
id[x]=++cnt
siz[x]=1
for(i=x的相邻节点)
if i≠f
dfs(i,x)
siz[x]+=siz[i]
end
cchg(x)
l←id[x],r←id[x]+siz[x]-1
chg(l,r)
end
asks(x)
l←id[x],r←id[x]+siz[x]-1
return ask(l,r)
end
接下来才是重头戏,我们不仅要解决子树的问题,还要解决链上的问题
于是就有了轻重链剖分
所谓重链,就是连续重边连成的链;所谓重边,就是某子树根节点与其重儿子连成的边;所谓重儿子,就是 siz s i z 最大的那一个儿子。
↑这是经过剖分的树,虚线为轻边,实线为重边
我们可以得出寻找重儿子的递归代码:
dfs(x,f)
siz[x]=1
hwy=0;
for(i=x的相邻节点)
if i≠f
dfs(i,x)
siz[x]+=siz[i]
if hwy<siz[i]
sn[x]=i
hwy=siz[i]
end
但是仅仅找出重儿子还不够,我们还要将其利用起来
于是我们需要记录重链
于是我们需要记录链头并保证重链在 dfs d f s 序中连续
于是就有了下面这样的两段dfs
void dfs1(int x,int f,int deep){
tp[x]=deep;
siz[x]=1;
fa[x]=f;
int hwy=0;
for(int i=h[x];i;i=a[i].li){
if(a[i].nx!=f){
dfs1(a[i].nx,x,deep+1);
siz[x]+=siz[a[i].nx];
if(siz[a[i].nx]>hwy){
sn[x]=a[i].nx;
hwy=siz[a[i].nx];
}
}
}return;
}
void dfs2(int x,int tpx){
id[x]=++cnt;
top[x]=tpx;
if(!sn[x]) return;
dfs2(sn[x],tpx);
for(int i=h[x];i;i=a[i].li){
if(a[i].nx!=fa[x]&&a[i].nx!=sn[x]){
dfs2(a[i].nx,a[i].nx);
}
}return;
}
dfs1(root,0,1);
dfs2(root,root);
tp[x] t p [ x ] 记录深度, top[x] t o p [ x ] 记录链顶
这样对于同一条重链上的修改和查询就迎刃而解了
接下来我们要解决最后一个问题,如果两个节点不在同一条重链上怎么办?
像求 LCA L C A 一样向上跳!
下面给出类似于 LCA L C A 求解的跳跃代码
void jump(int x,int y){
while(top[x]!=top[y]){
if(tp[top[x]]<tp[top[y]]) swap(x,y);
x所在重链链上操作
x=fa[top[x]];
}if(tp[x]>tp[y]) swap(x,y);
x,y同一条重链链上操作
return;
}
至于 dfs d f s 序,我们用线段树维护就好
下面是完整的模板↓
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXN=1<<17;
int n,m,cnt,root,MOD,x,y,np,z,p;
int tp[MAXN],fa[MAXN],siz[MAXN],sn[MAXN],val[MAXN];
int h[MAXN],id[MAXN],ren[MAXN],top[MAXN];
int tree[MAXN<<1],pls[MAXN<<1];
struct rpg{
int li,nx;
}a[MAXN<<1];
void add(int ls,int nx){
a[++np]=(rpg){h[ls],nx};
h[ls]=np;
}
void po(int k,int l,int r){
if(l==r||pls[k]==0) return;
int i=k<<1,mid=l+r>>1;
pls[i]=(pls[i]+pls[k])%MOD;
pls[i|1]=(pls[i|1]+pls[k])%MOD;
tree[i]=(tree[i]+pls[k]*(mid-l+1))%MOD;
tree[i|1]=(tree[i|1]+pls[k]*(r-mid))%MOD;
pls[k]=0;
}
void cadd(int k,int l,int r,int le,int ri,int x){
po(k,l,r);
if(le<=l&&r<=ri){
pls[k]=x;
tree[k]=(tree[k]+x*(r-l+1))%MOD;
return;
}int i=k<<1,mid=l+r>>1;
if(le<=mid) cadd(i,l,mid,le,ri,x);
if(mid<ri) cadd(i|1,mid+1,r,le,ri,x);
tree[k]=(tree[i]+tree[i|1])%MOD;
}
int ask(int k,int l,int r,int le,int ri){
po(k,l,r);
if(le<=l&&r<=ri) return tree[k]%MOD;
int i=k<<1,mid=l+r>>1;
int sum=0;
if(le<=mid) sum=(sum+ask(i,l,mid,le,ri))%MOD;
if(mid<ri) sum=(sum+ask(i|1,mid+1,r,le,ri))%MOD;
return sum;
}
void cadd1(int x,int y,int z){
while(top[x]!=top[y]){
if(tp[top[x]]<tp[top[y]]) swap(x,y);
cadd(1,1,n,id[top[x]],id[x],z);
x=fa[top[x]];
}if(tp[x]>tp[y]) swap(x,y);
cadd(1,1,n,id[x],id[y],z);
return;
}
int ask1(int x,int y){
long long sum=0;
while(top[x]!=top[y]){
if(tp[top[x]]<tp[top[y]]) swap(x,y);
sum+=ask(1,1,n,id[top[x]],id[x]);
if(sum>=MOD) sum%=MOD;
x=fa[top[x]];
}if(tp[x]>tp[y]) swap(x,y);
sum+=ask(1,1,n,id[x],id[y]);
return sum%MOD;
}
void init(){
scanf("%d%d%d%d",&n,&m,&root,&MOD);
for(int i=1;i<=n;++i){
scanf("%d",&val[i]);
}for(int i=1;i<n;++i){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}return;
}
void dfs1(int x,int f,int deep){
tp[x]=deep;
fa[x]=f;
siz[x]=1;
int hwy=0;
for(int i=h[x];i;i=a[i].li){
if(a[i].nx!=f){
dfs1(a[i].nx,x,deep+1);
siz[x]+=siz[a[i].nx];
if(siz[a[i].nx]>hwy){
sn[x]=a[i].nx;
hwy=siz[a[i].nx];
}
}
}return;
}
void dfs2(int x,int topx){
id[x]=++cnt;
ren[cnt]=val[x];
top[x]=topx;
if(!sn[x]) return;
dfs2(sn[x],topx);
for(int i=h[x];i;i=a[i].li){
if(a[i].nx!=fa[x]&&a[i].nx!=sn[x]){
dfs2(a[i].nx,a[i].nx);
}
}return;
}
void build(int k,int l,int r){
if(l==r){
tree[k]=ren[l];
return;
}int i=k<<1,mid=l+r>>1;
build(i,l,mid);
build(i|1,mid+1,r);
tree[k]=(tree[i]+tree[i|1])%MOD;
}
void solve(){
while(m--){
scanf("%d",&p);
if(p==1) scanf("%d%d%d",&x,&y,&z),cadd1(x,y,z%MOD);
else if(p==2) scanf("%d%d",&x,&y),printf("%d\n",ask1(x,y));
else if(p==3) scanf("%d%d",&x,&z),cadd(1,1,n,id[x],id[x]+siz[x]-1,z%MOD);
else scanf("%d",&x),printf("%d\n",ask(1,1,n,id[x],id[x]+siz[x]-1));
}return;
}
int main(){
init();
dfs1(root,0,1);
dfs2(root,root);
build(1,1,n);
solve();
return 0;
}