Description
submit
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
Input
第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
Output
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
Sample Input
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
Sample Output
6
9
13
HINT
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。
思路
一个裸的树链剖分
首先连边,然后跑两边dfs
第一遍求每个节点的父节点和子树大小
第二遍重新编号每个节点并求出重链和每个点子树中的最大值
然后建一个空树,主要是为了找到结构体中每个区间的左右端点
再把每个重新编号的元素添加进树
最后根据每个要求求解即可
代码
#include <bits/stdc++.h>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define mid ((tr[rt].l+tr[rt].r)>>1)
#define N 100001
#define ll long long
using namespace std;
inline int read(){
int ret=0,f=1;char c=getchar();
for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())ret=ret*10+c-'0';
return ret*f;
}
int n,m,v[N],pos[N],top[N],bl[N],pp=0,he[N];
int siz[N],ma[N],cnt=0,fa[N];
struct pppp{int l,r;ll sum,tag;}tr[N<<2];
struct derpp{int to,nxt;}a[N<<2];
inline void add(int x,int y){
a[++pp]=(derpp){y,he[x]};he[x]=pp;
a[++pp]=(derpp){x,he[y]};he[y]=pp;
}
void build(int rt,int l,int r){
tr[rt].l=l;tr[rt].r=r;
if(l==r)return ;
build(ls,l,mid);
build(rs,mid+1,r);
}
void dfs(int x){
siz[x]=1;
for(int i=he[x];~i;i=a[i].nxt){
int v=a[i].to;
if(v!=fa[x]){
fa[v]=x;
dfs(v);
siz[x]+=siz[v];
ma[x]=max(ma[x],ma[v]);
}
}
}
void dfs2(int x,int father){
bl[x]=father;pos[x]=ma[x]=++cnt;
int k=0,v;
for(int i=he[x];~i;i=a[i].nxt){
v=a[i].to;
if(v!=fa[x]&&siz[v]>siz[k])k=v;
}
if(k){
dfs2(k,father);
ma[x]=max(ma[x],ma[k]);
}
for(int i=he[x];~i;i=a[i].nxt){
v=a[i].to;
if(v!=fa[x]&&v!=k){
dfs2(v,v);
ma[x]=max(ma[x],ma[v]);
}
}
}
void pushdown(int rt){
if(tr[rt].l==tr[rt].r)return ;
tr[ls].tag+=tr[rt].tag;tr[rs].tag+=tr[rt].tag;
tr[ls].sum+=tr[rt].tag*(mid-tr[rt].l+1);
tr[rs].sum+=tr[rt].tag*(tr[rt].r-mid);
tr[rt].tag=0;
}
void add(int rt,int x,int y,ll val){
if(tr[rt].tag)pushdown(rt);
if(tr[rt].l==x&&tr[rt].r==y){tr[rt].tag+=val;tr[rt].sum+=(tr[rt].r-tr[rt].l+1)*val;return ;}
if(x<=mid)add(ls,x,min(mid,y),val);
if(y>=mid+1)add(rs,max(x,mid+1),y,val);
tr[rt].sum=tr[ls].sum+tr[rs].sum;
}
ll query(int rt,int x,int y){
if(tr[rt].tag)pushdown(rt);
if(tr[rt].l==x&&tr[rt].r==y)return tr[rt].sum;
ll ans=0;
if(x<=mid)ans+=query(ls,x,min(mid,y));
if(y>=mid+1)ans+=query(rs,max(mid+1,x),y);
return ans;
}
ll query(int x){
ll ans=0;
while(bl[x]!=1){
ans+=query(1,pos[bl[x]],pos[x]);
x=fa[bl[x]];
}
ans+=query(1,1,pos[x]);
return ans;
}
int main(){int x,y,opt;
memset(he,-1,sizeof(he));
n=read();m=read();
for(int i=1;i<=n;++i)v[i]=read();
for(int i=1;i<n;++i){
x=read();y=read();
add(x,y);
}
dfs(1);dfs2(1,1);
build(1,1,n);
for(int i=1;i<=n;++i)add(1,pos[i],pos[i],v[i]);
while(m--){
opt=read();x=read();
if(opt==1){
y=read();
add(1,pos[x],pos[x],y);
}
else if(opt==2){
y=read();
add(1,pos[x],ma[x],y);
}
else{
printf("%lld\n",query(x));
}
}
return 0;
}