自己码力好弱,,
洛谷:树链剖分模板
#include<bits/stdc++.h>
using namespace std;
const int M=1e5+10;
int n,m,r,cnt,p,tot,d[M],head[M<<1],siz[M],dep[M],son[M],fa[M],top[M],id[M],rk[M];
struct Edge{ int nex,to;}e[M<<1];
struct Node{int l,r,sum,tag;}t[M<<4];
void add(int x,int to){
e[cnt].to=to;
e[cnt].nex=head[x];
head[x]=cnt++;
}
void dfs1(int x,int f){
fa[x]=f;
siz[x]=1;
dep[x]=dep[f]+1;
for(int to,i=head[x];~i;i=e[i].nex){
to=e[i].to;
if(to==f)continue;
dfs1(to,x);
siz[x]+=siz[to];
if(siz[son[x]]<siz[to])son[x]=to;
}
}
void dfs2(int x,int tp){
top[x]=tp;
id[x]=++tot;
rk[tot]=x;
if(son[x])dfs2(son[x],tp);
for(int to,i=head[x];~i;i=e[i].nex){
to=e[i].to;
if(to!=fa[x]&&to!=son[x])
dfs2(to,to);
}
}
void build(int k,int l,int r){
t[k].l=l;t[k].r=r;t[k].tag=0;
if(l==r){
t[k].sum=d[rk[l]]%p;
return ;
}
int mid=(l+r)>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
t[k].sum=(t[k*2].sum+t[k*2+1].sum)%p;
}
void down(int k){
if(t[k].tag){
t[k*2].tag=(t[k].tag+t[k*2].tag)%p;
t[k*2].sum=(t[k*2].sum+(t[k*2].r-t[k*2].l+1)*t[k].tag%p)%p;
t[k*2+1].tag=(t[k].tag+t[k*2+1].tag)%p;
t[k*2+1].sum=(t[k*2+1].sum+(t[k*2+1].r-t[k*2+1].l+1)*t[k].tag%p)%p;
t[k].tag=0;
}
}
void update(int k,int l,int r,int w){
if(l<=t[k].l&&t[k].r<=r){
t[k].tag=(t[k].tag+w)%p;
t[k].sum=(t[k].sum+(t[k].r-t[k].l+1)*w)%p;
return ;
}
down(k);
int mid=(t[k].l+t[k].r)>>1;
if(l<=mid)update(k*2,l,r,w);
if(mid<r)update(k*2+1,l,r,w);
t[k].sum=(t[k*2].sum+t[k*2+1].sum)%p;
}
void updates(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,id[top[x]],id[x],z);
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
update(1,id[x],id[y],z);
}
int ask(int k,int l,int r){
if(l<=t[k].l&&t[k].r<=r)return t[k].sum;
down(k);
int mid=(t[k].l+t[k].r)>>1;
int res=0;
if(l<=mid)res+=ask(k*2,l,r);
if(mid<r)res+=ask(k*2+1,l,r);
return res%p;
}
int asks(int x,int y){
int res=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=(ask(1,id[top[x]],id[x])+res)%p;
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);//这里是一个错误
res=(ask(1,id[x],id[y])+res)%p;
return res%p;
}
int main(){
memset(head,-1,sizeof(head));
cnt=1;
tot=0;
scanf("%d%d%d%d",&n,&m,&r,&p);
for(int i=1;i<=n;i++)scanf("%d",d+i);
for(int x,y,i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dep[r]=1;
dfs1(r,0);
dfs2(r,r);
build(1,1,n);
int op,x,y,z;
while(m--){
scanf("%d",&op);
switch(op){
case 1:
scanf("%d%d%d",&x,&y,&z);
updates(x,y,z);
break;
case 2:
scanf("%d%d",&x,&y);
printf("%d\n",asks(x,y));
break;
case 3:
scanf("%d%d",&x,&z);
update(1,id[x],id[x]+siz[x]-1,z);
break;
case 4:
scanf("%d",&x);
printf("%d\n",ask(1,id[x],id[x]+siz[x]-1));
break;
}
}
return 0;
}