雅礼集训6.20T1
题目描述

题解

代码
#include<bits/stdc++.h>
#define int long long
#define M 600009
using namespace std;
int read(){
int f=1,re=0;
char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1,ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
struct zb{
int l,r,sum,add;
}tr[M];
int n,q,a[M],idx[M],num[M],top[M],f[M],dep[M],rt,cnt,son[M],nxt[M*2],first[M],to[M*2],tot,size[M];
void add(int x,int y){
nxt[++tot]=first[x];
first[x]=tot;
to[tot]=y;
}
void build(int k,int l,int r){
tr[k].l=l;tr[k].r=r;
if(l==r){
tr[k].sum=a[idx[l]];
return;
}int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
}
void update(int k,int l,int r,int val){
tr[k].add+=val;
tr[k].sum+=val*(r-l+1);
}
void pushdown(int k,int l,int r,int mid){
if(!tr[k].add) return;
update(k<<1,l,mid,tr[k].add);
update(k<<1|1,mid+1,r,tr[k].add);
tr[k].add=0;
}
void modify(int k,int l,int r,int val){
if(tr[k].l>=l&&tr[k].r<=r) return update(k,tr[k].l,tr[k].r,val);
int mid=(tr[k].l+tr[k].r)>>1;
pushdown(k,tr[k].l,tr[k].r,mid);
if(l<=mid) modify(k<<1,l,r,val);
if(r>mid) modify(k<<1|1,l,r,val);
tr[k].sum=(tr[k<<1].sum+tr[k<<1|1].sum);
}
int solve(int k,int l,int r){
if(tr[k].l>=l&&tr[k].r<=r) return tr[k].sum;
int mid=(tr[k].l+tr[k].r)>>1,ret=0;
pushdown(k,tr[k].l,tr[k].r,mid);
if(l<=mid) ret=(ret+solve(k<<1,l,r));
if(r>mid) ret=(ret+solve(k<<1|1,l,r));
return ret;
}
void dfs1(int r,int fa){
dep[r]=dep[fa]+1;
size[r]=1;
f[r]=fa;
for(int i=first[r];i;i=nxt[i]){
int u=to[i];
if(u==fa) continue;
dfs1(u,r);
size[r]+=size[u];
if(size[u]>size[son[r]]) son[r]=u;
}
}
void dfs2(int r,int tp){
top[r]=tp;
num[r]=++cnt;
idx[num[r]]=r;
if(son[r]) dfs2(son[r],tp);
for(int i=first[r];i;i=nxt[i]){
int u=to[i];
if(!num[u]) dfs2(u,u);
}
}
int getmax(int x,int y,int z){
int maxn;
if(dep[x]>dep[y]) maxn=x;
else maxn=y;
if(dep[maxn]<dep[z]) maxn=z;
return maxn;
}
int getlca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=f[top[x]];
}if(dep[x]<dep[y]) return x;
return y;
}
int getson(int x,int y){
while(top[x]!=top[y]){
if(y==f[top[x]]) return top[x];
x=f[top[x]];
}
while(x!=y){
if(dep[x]==dep[y]+1) return x;
x=f[x];
}
}
void tree_update(int x,int y){modify(1,num[x],num[x]+size[x]-1,y);}
int tree_query(int x){return solve(1,num[x],num[x]+size[x]-1);}
signed main(){
n=read(),q=read();rt=1;
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
add(y,x),add(x,y);
}dfs1(1,0),dfs2(1,1),build(1,1,n);
for(int i=1;i<=q;i++){
int opt=read(),x,y,z;
if(opt==1) rt=read();
if(opt==2){
x=read(),y=read(),z=read();
int lca1=getlca(x,y);
int lca2=getlca(lca1,rt);
if(lca2==lca1){
tree_update(1,z);
int v=getlca(x,rt),w=getlca(y,rt);
if(v!=rt&&w!=rt){
if(dep[v]<dep[w]) swap(v,w);
v=getson(rt,v);
tree_update(v,-z);
}
}else tree_update(lca1,z);
}if(opt==3){
int lca=getlca(x=read(),rt);
if(lca==x){
int ans=tree_query(1);
if(x!=rt) x=getson(rt,x),ans-=tree_query(x);
printf("%lld\n",ans);
}else printf("%lld\n",tree_query(x));
}
}return 0;
}