树上操作
Description:
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
Input:
第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
Output:
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
Sample Input
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
Sample Output
6
9
13
HINT:
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6
题解:
此题为一道很裸的树链剖分的题,下面来介绍一下树链剖分。
树链剖分:
树链剖分有几个基本概念:
1.重点:即相比于兄弟节点,子树节点最多的点
2.重链:重点连成的链
树链剖分就是剖分出重链,然后把重链连在一起, 再用数据结构来维护重链.
剖分重链有很多办法,这里我介绍两次DFS的方法.
由于涉及到的数组多,我先解释一下它们各自的含义:
top[i] i节点所在重链的起点
siz[i] i节点的大小(子节点个数)
fat[i] i节点的父亲节点
son[i] i节点的中儿子
in[i] i节点在dfs中第一次访问的次序
out[i] i节点在dfs中回溯的次序
void dfs_1( int x ){
siz[x] = 1;
for ( int i = head[x]; i; i = next[i])
if ( to[i] != fat[x] ){
int v = to[i];
dep[v] = dep[u] + 1;
fat[v] = x;
dfs_1( v );
siz[x] += siz[v];
}
}
第一遍DFS求出fat[], siz[], dep[]的值,很好理解,这里不过多交代。
void dfs_2( int x, int tp ){
top[x] = tq;
in[x] = ++id;
int k = 0;
for ( int i = head[x]; i; i = next[i])
if ( to[i] != fa[x] && siz[ to[i] ] > siz[k] ) k = to[i];
if ( k ){
dfs_2( k, tq );
}
for ( int i = head[x]; i; i = next[i])
if ( to[i] != fat[x] && to[i] != k ){
int v = to[i];
dfs_2( v, v );
}
out[x] = id;
}
第二次DFS根据前面算出的siz[]来找中儿子,即代码中的k,在用时间戳id求出in[]和out[].
这道题用DFS剖出重链以后,再用线段树维护in[],线段树要提供区间求和、区间加减的功能。
/**************************************************************
Problem: 4034
User: Venishel
Language: C++
Result: Accepted
Time:2760 ms
Memory:16456 kb
****************************************************************/
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1e5 + 7;
#define LL long long
struct Edge{
int nxt, to;
}e[N<<1];
int head[N], tot=0;
void addeage(int u, int v){
e[++tot].nxt=head[u], e[tot].to=v;
head[u]=tot;
}
int sz[N], mx[N], pos[N], fat[N], v[N], bl[N];
int id=0, n, m;
void dfs1(int u, int fa){
sz[u]=1;
for ( int i=head[u]; i; i=e[i].nxt ){
int v=e[i].to;
if( v==fa ) continue;
fat[v]=u;
dfs1(v, u);
sz[u]+=sz[v];
}
}
void dfs2(int u, int fa, int fq){
bl[u]=fq;
// printf("%d %d\n", u, bl[u] );
pos[u]=mx[u]=++id;
int k=0;
for ( int i=head[u]; i; i=e[i].nxt ) if( sz[k]<sz[e[i].to] && e[i].to!=fat[u] ) k=e[i].to;
if( k ){
dfs2(k, u, fq);
mx[u]=max(mx[u], mx[k]);
}
for ( int i=head[u]; i; i=e[i].nxt ){
int v=e[i].to;
if( v==fa || v==k ) continue;
dfs2(v, u, v);
mx[u]=max(mx[u], mx[v]);
}
}
struct Node{
int flg;
LL sum, add;
}tr[N<<2];
#define ls nd<<1
#define rs nd<<1|1
void pushdown(int nd, int l, int r){
if( tr[nd].flg ){
tr[nd].flg=0;
tr[ls].flg=tr[rs].flg=1;
tr[ls].add+=(LL)tr[nd].add, tr[rs].add+=(LL)tr[nd].add;
int mid=(l+r)>>1;
tr[ls].sum+=(LL)(mid-l+1)*tr[nd].add;
tr[rs].sum+=(LL)(r-mid)*tr[nd].add;
tr[nd].add=0;
return ;
}
}
void pushup(int nd){
tr[nd].sum=tr[ls].sum+tr[rs].sum;
}
void modify(int nd, int l, int r, int L, int R, int val){
if( L<=l && r<=R ){
tr[nd].flg=1;
tr[nd].add+=(LL)val;
tr[nd].sum+=(LL)(r-l+1)*val;
return;
}
pushdown(nd,l,r);
int mid=(l+r)>>1;
if( mid>=L ) modify(ls,l,mid,L,R,val);
if( mid<R ) modify(rs,mid+1,r,L,R,val);
pushup(nd);
}
LL query(int nd, int l, int r, int L, int R){
if( L<=l && r<= R ){
return tr[nd].sum;
}
pushdown(nd,l,r);
int mid=(l+r)>>1;
LL ret=0;
if ( mid>=L ) ret+=query(ls,l,mid,L,R);
if ( mid<R ) ret+=query(rs,mid+1,r,L,R);
return ret;
}
LL query(int nd){
LL ret=0;
while( bl[nd]!=1 ){
ret+=query(1,1,n,pos[bl[nd]],pos[nd]);
nd=fat[bl[nd]];
}
ret+=query(1,1,n,1,pos[nd]);
return ret;
}
/*
inline int read(){
int f=1, x=0; char ch=getchar();
while( !isdigit(ch) ) { if(ch=='-')f=-1; ch=getchar(); }
while( isdigit(ch) ) { x=x*10+ch-'0'; ch=getchar(); }
return x*f;
}*/
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int main(){
n=read(), m=read();
for ( int i=1; i<=n; i++ ) v[i]=read();
for ( int i=1; i<n; i++ ){
int x, y;
x=read(), y=read();
addeage(x,y), addeage(y,x);
}
dfs1(1,0);
dfs2(1,0,1);
// for ( int i=1; i<=n; i++ ) printf("%d ", pos[i] );
// cout<<"\n";
// for ( int i=1; i<=n; i++ ) printf("%d ", mx[i] );
//---------------------------------------------------------------------------
for ( int i=1; i<=n; i++ ) modify(1,1,n,pos[i],pos[i],v[i]);
while( m-- ){
int opt, x, y;
opt=read();
if( opt==3 ) x=read(), printf("%lld\n", query(x) );
else if( opt==1 ) x=read(), y=read(), modify(1,1,n,pos[x],pos[x],y);
else x=read(), y=read(), modify(1,1,n,pos[x],mx[x],y);
}
}