何为树链剖分
树链剖分,顾名思义,就是将一棵树解剖,映射到一个链上,从而进行
1.子树/路径修改
2.子树/路径查询(此处路径为路径上的点)
所涉名词
-
重儿子 : 设一个节点x的子节点中size最大的是u,则称u为此节点的重儿子。
-
轻儿子 : 除重儿子外的其他节点。
-
重边 :x与u之间的边。
-
轻边 :与其他儿子的连边。
-
重链 :由 重边 连成的路径。
-
轻链 : 即 轻边 连成的路径。
一道板子题
代码解析
所涉数组:
in[]/out[]:dfs序;
head[]:存边基本操作;
num[]:映射完成后的链;
dep[]:节点深度;
son[]:重儿子;
fa[]:父节点;
top[]:链头(对于在链上快速跳跃有重大意义);
siz[]:子树大小;
pos[]:一个节点在链上的位置;
我相信线段树的相关操作你们已经会了,所以只解释其他函数。
预处理
void dfs1(int k)
{
siz[k]=1;
for(int i=head[k];i>=0;i=e[i].next)
if(e[i].y!=fa[k])
{
fa[e[i].y]=k;
dep[e[i].y]=dep[k]+1;
dfs1(e[i].y);
siz[k]+=siz[e[i].y];
if(siz[son[k]]<siz[e[i].y])son[k]=e[i].y;
}
}
void dfs2(int k,int tp)
{
top[k]=tp;
pos[k]=++cnt;
in[k]=cnt;
num[cnt]=a[k];
if(son[k])dfs2(son[k],tp);//是重链时继承
for(int i=head[k];i>=0;i=e[i].next)
if(e[i].y!=fa[k]&&e[i].y!=son[k])
dfs2(e[i].y,e[i].y);//是轻链时分家
out[k]=cnt;
}
第一个dfs处理出处理了深度,子树大小,重儿子等问题,相当于种下了树苗;
第二个dfs则处理了链头,位置,dfs序等问题,还顺带考虑了轻重链。
这里用到了dfs序,dfs序可以理解为dfs路线的步骤,当走到这一点时记录一下,in[i]=cnt++,离开时再记录一下,out[i]=cnt,于是就得到了下图(点内表示in[i]/out[i])

很容易发现,子树的dfs序是一个连续的区间,通过这一点,我们可以轻易地将子树修改转化为区间修改。根据需要,用线段树,树状数组,splay(还不会)等数据结构维护即可。
修改
当我们处理完轻重链时,就要考虑修改的问题。
(每个题的修改都是不一样的,我当然只会这个模板里的(逃))
1.将x到y的路径上都加上z。
void addpath()
{
scanf("%d%d%d",&x,&y,&z);
z%=p;
while(top[x]!=top[y])//比较链头
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
add(1,pos[top[x]],pos[x],z);
x=fa[top[x]];//跳出链
}
if(dep[x]<dep[y])swap(x,y);//修改之后深度不保证,所以默认一下,便于之后的操作。
add(1,pos[y],pos[x],z);//y比x浅,肯定比x先遍历到。
}
这里用到了类似于LCA的思想。我们在实际处理时,只需要每次看x与y在不在同一条链上(默认x深于y),不在的话,x跳出此链,并对这条链进行修改。当他们在一条链上时,就可以直接操作了。
如图
寻找8到5 ,8会直接跳到2,再进行操作。
2.将x子树加上z。
直接给in[x]到out[x]加上z就可以了,简单的不行(小声比比)
查询
与修改大同小异,这里就不再赘述了。
代码全貌
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#define maxn 100005
using namespace std;
int n,m,R,p;
struct node{
int x,y,next;
}e[maxn*2];
struct line{
int l,r,sum,tag;
}t[maxn*4];
int in[maxn],out[maxn],head[maxn],num[maxn];
int dep[maxn],son[maxn],fa[maxn],top[maxn],siz[maxn],pos[maxn];
int a[maxn];
int cnt=0;
int tot=0;
int x,y,z;
void ad(int x,int y)
{
++tot;
e[tot].x=x;e[tot].y=y;e[tot].next=head[x];head[x]=tot;
}
void dfs1(int k)
{
siz[k]=1;
for(int i=head[k];i>=0;i=e[i].next)
if(e[i].y!=fa[k])
{
fa[e[i].y]=k;
dep[e[i].y]=dep[k]+1;
dfs1(e[i].y);
siz[k]+=siz[e[i].y];
if(siz[son[k]]<siz[e[i].y])son[k]=e[i].y;
}
}
void dfs2(int k,int tp)
{
top[k]=tp;
pos[k]=++cnt;
in[k]=cnt;
num[cnt]=a[k];
if(son[k])dfs2(son[k],tp);
for(int i=head[k];i>=0;i=e[i].next)
if(e[i].y!=fa[k]&&e[i].y!=son[k])
dfs2(e[i].y,e[i].y);
out[k]=cnt;
}
void biu(int k,int l,int r)
{
t[k].l=l;t[k].r=r;
if(l==r){t[k].sum=num[l];return ;}
int mid=(l+r)>>1;
biu(k<<1,l,mid);biu(k<<1|1,mid+1,r);
t[k].sum=t[k<<1].sum+t[k<<1|1].sum;
}
void pushd(int k)
{
if(t[k].tag)
{
if(t[k].l==t[k].r)
{
t[k].tag=0;return ;
}
t[k<<1].tag+=t[k].tag;t[k<<1].tag%=p;
t[k<<1|1].tag+=t[k].tag;t[k<<1|1].tag%=p;
t[k<<1].sum+=t[k].tag*(t[k<<1].r-t[k<<1].l+1);t[k<<1].sum%=p;
t[k<<1|1].sum+=t[k].tag*(t[k<<1|1].r-t[k<<1|1].l+1);t[k<<1|1].sum%=p;
}
t[k].tag=0;
}
void add(int k,int l,int r,int f)
{
pushd(k);
if(t[k].l==l&&t[k].r==r)
{
t[k].tag=f;
t[k].sum+=f*(t[k].r-t[k].l+1);
return ;
}
int mid=(t[k].l+t[k].r)>>1;
if(r<=mid)add(k<<1,l,r,f);
else if(l>mid)add(k<<1|1,l,r,f);
else add(k<<1,l,mid,f),add(k<<1|1,mid+1,r,f);
t[k].sum=t[k<<1].sum+t[k<<1|1].sum;t[k].sum%=p;
}
int ask(int k,int l,int r)
{
pushd(k);
if(t[k].l==l&&t[k].r==r)return t[k].sum;
int mid=(t[k].l+t[k].r)>>1;
if(r<=mid)return ask(k<<1,l,r);
else if(l>mid)return ask(k<<1|1,l,r);
else return (ask(k<<1,l,mid)+ask(k<<1|1,mid+1,r))%p;
}
void addpath()
{
scanf("%d%d%d",&x,&y,&z);
z%=p;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
add(1,pos[top[x]],pos[x],z);
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
add(1,pos[y],pos[x],z);
}
void addtree()
{
scanf("%d%d",&x,&z);
z%=p;
add(1,in[x],out[x],z);
}
void askpath()
{
scanf("%d%d",&x,&y);
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans+=ask(1,pos[top[x]],pos[x]);ans%=p;
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
ans+=ask(1,pos[y],pos[x]);ans%=p;
printf("%d\n",ans);
}
void asktree()
{
scanf("%d",&x);
printf("%d\n",ask(1,in[x],out[x])%p);
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d%d%d",&n,&m,&R,&p);
for(int i=1;i<=n;i++)scanf("%d",&a[i]),a[i]%=p;
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
ad(a,b);ad(b,a);
}
dfs1(R);
dfs2(R,R);
biu(1,1,n);
int op;
while(m--)
{
scanf("%d",&op);
if(op==1)addpath();
else if(op==3)addtree();
else if(op==2)askpath();
else asktree();
}
}