题目描述
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
输入输出格式
输入格式:
第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。
接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。
接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)
接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:
操作1: 1 x y z
操作2: 2 x y
操作3: 3 x z
操作4: 4 x
输出格式:
输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)
#include<cstdio>
#include<string>
#define lson o<<1
#define rson o<<1|1
using namespace std;
const int maxn=1000005;
int N,M,R,P,cnt;
int tot,to[maxn],nxt[maxn],lnk[maxn];
int sumv[maxn<<2],tag[maxn<<2],a[maxn];
int dep[maxn],fa[maxn],son[maxn],siz[maxn],top[maxn],w[maxn],id[maxn];
inline int read() {
int ret=0,f=1;char ch=getchar();
for (; !isdigit(ch); ch=getchar()) if (ch=='-') f=-f;
for (; isdigit(ch); ch=getchar()) ret=ret*10+ch-48;
return ret*f;
}
inline void add_edge(int x,int y) {
to[++tot]=y,nxt[tot]=lnk[x],lnk[x]=tot;
to[++tot]=x,nxt[tot]=lnk[y],lnk[y]=tot;
}
inline void dfs1(int now,int fat) {
dep[now]=dep[fat]+1,fa[now]=fat,siz[now]=1;
for (int k=lnk[now]; k; k=nxt[k]) if (to[k]^fat) {
dfs1(to[k],now);
siz[now]+=siz[to[k]];
if (siz[to[k]]>siz[son[now]]) son[now]=to[k];
}
}
inline void dfs2(int now,int topf) {
id[now]=++cnt,a[cnt]=w[now],top[now]=topf;
if (son[now]==0) return;
dfs2(son[now],topf);
for (int k=lnk[now]; k; k=nxt[k])
if (to[k]^fa[now]&&to[k]^son[now])
dfs2(to[k],to[k]);
}
inline void pushup(int o) {sumv[o]=sumv[lson]+sumv[rson];}
inline void pushdown(int o,int l,int r) {
if (tag[o]==0) return;
int mid=l+(r-l>>1);
(sumv[lson]+=(tag[o]*(mid-l+1))%P)%=P,(tag[lson]+=tag[o])%=P;
(sumv[rson]+=(tag[o]*(r - mid))%P)%=P,(tag[rson]+=tag[o])%=P;
tag[o]=0;
}
inline void build(int o,int l,int r) {
if (l==r) {sumv[o]=a[l];return;}
int mid=l+(r-l>>1);
build(lson,l,mid),build(rson,mid+1,r);
pushup(o);
}
inline void updata(int o,int l,int r,int ql,int qr,int dat) {
if (l>qr||r<ql) return;
if (l>=ql&&r<=qr) {(sumv[o]+=(dat*(r-l+1))%P)%=P;(tag[o]+=dat)%=P;return;}
int mid=l+(r-l>>1);
pushdown(o,l,r);
updata(lson,l,mid,ql,qr,dat),updata(rson,mid+1,r,ql,qr,dat);
pushup(o);
}
inline int query(int o,int l,int r,int ql,int qr) {
if (l>qr||r<ql) return 0;
if (l>=ql&&r<=qr) return sumv[o];
int mid=l+(r-l>>1);
pushdown(o,l,r);
return (query(lson,l,mid,ql,qr)%P+query(rson,mid+1,r,ql,qr)%P)%P;
}
inline void doit1(int z,int y,int x) {
while (top[x]^top[y]) {
if (dep[top[x]]<dep[top[y]]) swap(x,y);
updata(1,1,N,id[top[x]],id[x],z);
x=fa[top[x]];
}
if (id[x]>id[y]) swap(x,y);
updata(1,1,N,id[x],id[y],z);
}
inline int doit2(int y,int x) {
int ret=0;
while (top[x]^top[y]) {
if (dep[top[x]]<dep[top[y]]) swap(x,y);
(ret+=query(1,1,N,id[top[x]],id[x]))%=P;
x=fa[top[x]];
}
if (id[x]>id[y]) swap(x,y);
printf("%d\n",(ret+query(1,1,N,id[x],id[y]))%P);
}
inline void doit3(int z,int x) {updata(1,1,N,id[x],id[x]+siz[x]-1,z);}
inline int doit4(int x) {printf("%d\n",query(1,1,N,id[x],id[x]+siz[x]-1)%P);}
int main() {
N=read(),M=read(),R=read(),P=read();
for (int i=1; i<=N; i++) w[i]=read();
for (int i=1; i<N; i++) add_edge(read(),read());
dfs1(R,0),dfs2(R,R),build(1,1,N);
for (int i=1; i<=M; i++) {
int f=read();
if (f==1) doit1(read(),read(),read());
if (f==2) doit2(read(),read());
if (f==3) doit3(read(),read());
if (f==4) doit4(read());
}
return 0;
}