树链刨分模板题。
推荐一篇讲的比较好的博客:https://www.cnblogs.com/chinhhh/p/7965433.html
因为这题要求取模,所以线段树部分和树链刨分部分都要取模。
ACcode:
精简版:
#include<bits/stdc++.h>
const int N=1e5+10;
using namespace std;
int n,q,r,mod;
int a[N<<2],b[N<<2];
int tot,ver[N<<1],Next[N<<1],head[N<<1];
int dep[N],fa[N],Size[N],son[N];
int cnt,id[N],top[N];
/*
b[i] 初始节点的权值
a[i] 新建节点的权值
题目输入的是b[i],线段树中用的是a[i]
dep[i] 该节点的深度
fa[i] 该节点的父亲
Size[i] 该非叶子节点的子树大小
son[i] i的重儿子
cnt 新建节点时用来存编号的
id[i] 该节点的新编号
top[i] 该节点所在链的顶端节点编号
*/
void add(int x,int y)
{
ver[++tot]=y;Next[tot]=head[x];head[x]=tot;
}
void dfs1(int x,int f,int deep) //当前节点 父亲 深度
{
dep[x]=deep;
fa[x]=f;
Size[x]=1;
int maxson=-1;
for(int i=head[x];i;i=Next[i])
{
int y=ver[i];
if(y==f) continue;
dfs1(y,x,deep+1);
Size[x]+=Size[y];
if(Size[y]>maxson)
son[x]=y,maxson=Size[y];
}
}
void dfs2(int x,int topf) //当前节点 当前链的最顶端的节点
{
id[x]=++cnt;
a[cnt]=b[x];
top[x]=topf;
if(son[x]==0) return ;
dfs2(son[x],topf);
for(int i=head[x];i;i=Next[i])
{
int y=ver[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
//以下是线段树******************************************************************
struct stree
{
int l,r;
int sum,add;
} t[N<<2];
void build(int p,int l,int r) //建树 build(1,1,n)
{
t[p].l=l,t[p].r=r;
if(l==r)
{
t[p].sum=a[l];
return ;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
t[p].sum=(t[p*2].sum+t[p*2+1].sum)%mod;
}
void spread(int p)
{
if(t[p].add)
{
t[p*2].sum=(t[p*2].sum+t[p].add*(t[p*2].r-t[p*2].l+1)%mod)%mod;
t[p*2+1].sum=(t[p*2+1].sum+t[p].add*(t[p*2+1].r-t[p*2+1].l+1)%mod)%mod;
t[p*2].add+=t[p].add;
t[p*2+1].add+=t[p].add;
t[p].add=0;
}
}
//更新节点信息,调用入口:update(1,l,r,d),把第l到r个数都加d
void update(int p,int l,int r,int d)
{
if(l<=t[p].l&&r>=t[p].r)
{
t[p].sum=(t[p].sum+(int)d*(t[p].r-t[p].l+1)%mod)%mod;
t[p].add+=d;
return ;
}
spread(p);
int mid=(t[p].l+t[p].r)/2;
if(l<=mid) update(p*2,l,r,d);
if(r>mid) update(p*2+1,l,r,d);
t[p].sum=(t[p*2].sum+t[p*2+1].sum)%mod;
}
//第l到r个数的和,调用入口:query(1,l,r)
int query(int p,int l,int r)
{
if(l<=t[p].l&&r>=t[p].r)
return t[p].sum;
spread(p);
int mid=(t[p].l+t[p].r)/2;
int val=0;
if(l<=mid)
val=(val+query(p*2,l,r))%mod;
if(r>mid)
val=(val+query(p*2+1,l,r))%mod;
return val;
}
//以上是线段树*********************************************************************
void update1(int x,int y,int z) //给节点x到节点y这条链上的所有节点都加上z
{
z%=mod;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,id[top[x]],id[x],z);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,id[x],id[y],z);
}
int query1(int x,int y) //询问节点x到节点y这条链上的所有节点的和
{
int ret=0,res;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res=query(1,id[top[x]],id[x]);
ret+=res;
ret%=mod;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ret+=query(1,id[x],id[y]);
ret%=mod;
return ret;
}
void update2(int x,int z) //给以x为根的子树的所有节点(包括自己)都加上z
{
update(1,id[x],id[x]+Size[x]-1,z);
}
int query2(int x) //询问以x为根的子树的所有节点(包括自己)的和
{
return query(1,id[x],id[x]+Size[x]-1)%mod;
}
int main()
{
tot=0;
cnt=0;
scanf("%d%d%d%d",&n,&q,&r,&mod);
for(int i=1;i<=n;i++) scanf("%d",&b[i]);
int x,y;
for(int i=1;i<=n-1;i++)
{scanf("%d%d",&x,&y);add(x,y);add(y,x);}
dfs1(r,0,1);
dfs2(r,r);
build(1,1,n);
int ans;
while(q--)
{
int op,x,y,z;
scanf("%d",&op);
if(op==1)
{
scanf("%d%d%d",&x,&y,&z);
update1(x,y,z);
}
else if(op==2)
{
scanf("%d%d",&x,&y);
ans=query1(x,y);
printf("%d\n",ans);
}
else if(op==3)
{
scanf("%d%d",&x,&z);
update2(x,z);
}
else if(op==4)
{
scanf("%d",&x);
ans=query2(x);
printf("%d\n",ans);
}
}
return 0;
}
带注释版:
#include<bits/stdc++.h>
const int N=1e5+10;
using namespace std;
int n,q,r,mod;
int a[N<<2],b[N<<2];
int tot,ver[N<<1],Next[N<<1],head[N<<1]; //链式前向星加边
int dep[N],fa[N],Size[N],son[N];
int cnt,id[N],top[N];
/*
b[i] 初始节点的权值
a[i] 新建节点的权值
题目输入的是b[i],线段树中用的是a[i]
dep[i] 该节点的深度
fa[i] 该节点的父亲
Size[i] 该非叶子节点的子树大小
son[i] i的重儿子
cnt 新建节点时用来存编号的
id[i] 该节点的新编号
top[i] 该节点所在链的顶端节点编号
*/
void add(int x,int y)
{
ver[++tot]=y;Next[tot]=head[x];head[x]=tot;
}
void dfs1(int x,int f,int deep) //当前节点 父亲 深度
{
dep[x]=deep; //标记每个点的深度
fa[x]=f; //标记每个点的父亲
Size[x]=1; //标记每个非叶子节点的子树大小
int maxson=-1; //记录重儿子的儿子数
for(int i=head[x];i;i=Next[i])
{
int y=ver[i];
if(y==f) continue; //若该节点为父亲节点则continue
dfs1(y,x,deep+1); //dfs其儿子
Size[x]+=Size[y]; //统计子树大小
if(Size[y]>maxson) //查找每个非叶子节点的重儿子编号
son[x]=y,maxson=Size[y];
}
}
void dfs2(int x,int topf) //当前节点 当前链的最顶端的节点
{
id[x]=++cnt; //标记每个节点的新编号
a[cnt]=b[x]; //把每个节点的初始值赋给这个新编号
top[x]=topf; //记录这个点所在重链的最顶端
if(son[x]==0) return ; //如果没有儿子,直接返回
dfs2(son[x],topf); //先处理重儿子,在处理轻儿子
for(int i=head[x];i;i=Next[i])
{
int y=ver[i];
if(y==fa[x]||y==son[x]) continue; //如果该节点是父亲节点或者重儿子,continue
dfs2(y,y); //每个轻儿子都有一条以自己为链顶端的重链
}
}
//以下是线段树******************************************************************
struct stree
{
int l,r;
int sum,add;
} t[N<<2];
void build(int p,int l,int r) //建树 build(1,1,n)
{
t[p].l=l,t[p].r=r;
if(l==r)
{
t[p].sum=a[l];
return ;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
t[p].sum=(t[p*2].sum+t[p*2+1].sum)%mod;
}
void spread(int p)
{
if(t[p].add)
{
t[p*2].sum=(t[p*2].sum+t[p].add*(t[p*2].r-t[p*2].l+1)%mod)%mod;
t[p*2+1].sum=(t[p*2+1].sum+t[p].add*(t[p*2+1].r-t[p*2+1].l+1)%mod)%mod;
t[p*2].add+=t[p].add;
t[p*2+1].add+=t[p].add;
t[p].add=0;
}
}
//更新节点信息,调用入口:update(1,l,r,d),把第l到r个数都加d
void update(int p,int l,int r,int d)
{
if(l<=t[p].l&&r>=t[p].r)
{
t[p].sum=(t[p].sum+(int)d*(t[p].r-t[p].l+1)%mod)%mod;
t[p].add+=d;
return ;
}
spread(p);
int mid=(t[p].l+t[p].r)/2;
if(l<=mid) update(p*2,l,r,d);
if(r>mid) update(p*2+1,l,r,d);
t[p].sum=(t[p*2].sum+t[p*2+1].sum)%mod;
}
//第l到r个数的和,调用入口:query(1,l,r)
int query(int p,int l,int r)
{
if(l<=t[p].l&&r>=t[p].r)
return t[p].sum;
spread(p);
int mid=(t[p].l+t[p].r)/2;
int val=0;
if(l<=mid)
val=(val+query(p*2,l,r))%mod;
if(r>mid)
val=(val+query(p*2+1,l,r))%mod;
return val;
}
//以上是线段树*********************************************************************
void update1(int x,int y,int z) //给节点x到节点y这条链上的所有节点都加上z
{
//update1注释同query1
z%=mod;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,id[top[x]],id[x],z);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,id[x],id[y],z);
}
int query1(int x,int y) //询问节点x到节点y这条链上的所有节点的和
{
int ret=0,res;
while(top[x]!=top[y]) //这两个点不在一条链上
{
if(dep[top[x]]<dep[top[y]]) swap(x,y); //设x点为所在链顶端节点的深度更深的那个点
res=query(1,id[top[x]],id[x]); //ret加上“x点到x所在链顶端”这一段区间的点权和
ret+=res;
ret%=mod;
x=fa[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//此时两个点已经处于一条链上了
if(dep[x]>dep[y]) swap(x,y); //设x点为所在链顶端节点的深度更深的那个点
ret+=query(1,id[x],id[y]); //再加上两个点的区间和
ret%=mod;
return ret;
}
void update2(int x,int z) //给以x为根的子树的所有节点(包括自己)都加上z
{
//update2注释同query2
update(1,id[x],id[x]+Size[x]-1,z);
}
int query2(int x) //询问以x为根的子树的所有节点(包括自己)的和
{
return query(1,id[x],id[x]+Size[x]-1)%mod; //子树区间右断点为id[x]+siz[x]-1
}
int main()
{
tot=0;
cnt=0;
scanf("%d%d%d%d",&n,&q,&r,&mod);
for(int i=1;i<=n;i++) scanf("%d",&b[i]);
//原始节点的点权存在数组b里,新构造出来的点的点权存在数组a里,建线段树用的是数组a
int x,y;
for(int i=1;i<=n-1;i++)
{scanf("%d%d",&x,&y);add(x,y);add(y,x);}
dfs1(r,0,1);
dfs2(r,r);
build(1,1,n);
int ans;
while(q--)
{
int op,x,y,z;
scanf("%d",&op);
if(op==1)
{
scanf("%d%d%d",&x,&y,&z);
update1(x,y,z);
}
else if(op==2)
{
scanf("%d%d",&x,&y);
ans=query1(x,y);
printf("%d\n",ans);
}
else if(op==3)
{
scanf("%d%d",&x,&z);
update2(x,z);
}
else if(op==4)
{
scanf("%d",&x);
ans=query2(x);
printf("%d\n",ans);
}
}
return 0;
}