A National Pandemic
题目描述:
国家可以表示为 n n n 个节点 n − 1 n-1 n−1 条边的图。 F ( x ) F(x) F(x) 表示节点 x x x 的疫情严重性。有以下三种修改/查询:
- 疫情在 x x x 节点爆发,严重性为 x x x,对于每个节点 y y y, F ( y ) F(y) F(y)增加 w − d i s t ( x , y ) w-dist(x,y) w−dist(x,y),其中 d i s t ( x , y ) dist(x,y) dist(x,y) 表示节点 x x x 到节点 y y y 路径上边的数量。
- 将节点 x x x 的 F ( x ) F(x) F(x)更新为 m i n ( F ( x ) , 0 ) min(F(x), 0) min(F(x),0)。
- 询问节点 x x x 的 F ( x ) F(x) F(x)
输入描述:
有多个测试用例。 输入的第一行包含一个整数
T
(
1
≤
T
≤
5
)
T(1 \leq T \leq 5)
T(1≤T≤5),表示测试用例的数量。
对于每个测试用例,第一行包含两个整数
n
,
m
(
1
≤
n
,
m
≤
5
×
1
0
4
)
n,m(1 \leq n,m \leq 5 \times 10 ^ 4)
n,m(1≤n,m≤5×104),代表城市的数量以及事件和查询的数量。 以下
n
−
1
n-1
n−1行描述了该国家/地区的所有路径,每条路径均包含两个整数
x
,
y
(
1
≤
x
,
y
≤
n
)
x,y(1 \leq x,y \leq n)
x,y(1≤x,y≤n),代表城市
x
x
x和
y
y
y之间的道路。 以下
m
m
m行描述了所有事件,每个事件均以整数
o
p
t
(
1
≤
o
p
t
≤
3
)
\mathit {opt}(1 \leq \mathit {opt} \leq 3)
opt(1≤opt≤3)开始,并且如果
o
p
t
\mathit{opt}
opt为
- 在同一行中将有两个整数 x , w ( 1 ≤ x ≤ n , 0 ≤ w ≤ 10000 ) x,w(1 \leq x \leq n,0 \leq w \leq 10000) x,w(1≤x≤n,0≤w≤10000)。 这是指上面描述中的事件1。
- 在同一行中将有一个整数 x ( 1 ≤ x ≤ n ) x(1 \leq x \leq n) x(1≤x≤n)。 这是指事件2。
- 在同一行中将有一个整数 x ( 1 ≤ x ≤ n ) x(1 \leq x \leq n) x(1≤x≤n)。 这是指您需要答复的查询。
输出描述:
每个查询输出一个整数。
样例输入:
1
5 6
1 2
1 3
2 4
2 5
1 1 5
3 4
2 1
1 2 7
3 3
3 1
样例输出:
3
9
6
思路:
首先,我们对每一个操作进行分析:
o
p
t
=
1
:
opt=1:
opt=1:
在树上求距离可以用
l
c
a
lca
lca,于是我们把
d
i
s
t
dist
dist这个函数化开,设
x
,
y
,
l
c
a
x,y,lca
x,y,lca的深度分别为
d
e
p
[
x
]
,
d
e
p
[
y
]
,
d
e
p
[
l
c
a
]
:
dep[x],dep[y],dep[lca]:
dep[x],dep[y],dep[lca]:
w
−
d
i
s
t
(
x
,
y
)
w-dist(x,y)
w−dist(x,y)
=
w
−
(
d
e
p
[
x
]
+
d
e
p
[
y
]
−
2
d
e
p
[
l
c
a
]
)
=w-(dep[x]+dep[y]-2dep[lca])
=w−(dep[x]+dep[y]−2dep[lca])
=
w
−
d
e
p
[
x
]
−
d
e
p
[
y
]
+
2
d
e
p
[
l
c
a
]
=w-dep[x]-dep[y]+2dep[lca]
=w−dep[x]−dep[y]+2dep[lca]
由此我们发现:当
o
p
t
=
=
1
opt==1
opt==1时
w
−
d
e
p
[
x
]
w-dep[x]
w−dep[x]是固定的,而对于每一个节点,
d
e
p
[
y
]
dep[y]
dep[y]也是固定的,所以我们设
A
+
=
w
−
d
e
p
[
x
]
,
B
+
+
A+=w-dep[x],B++
A+=w−dep[x],B++
A
A
A表示所有关于
x
x
x的结果,
B
B
B表示
d
e
p
[
y
]
dep[y]
dep[y]的次数
所以对于每次操作,我们都可以查询一下之前所有的结果,即:
f
(
y
)
=
A
−
B
∗
d
e
p
[
y
]
+
2
∑
(
d
e
p
[
l
c
a
]
)
f(y)=A-B*dep[y]+2\sum(dep[lca])
f(y)=A−B∗dep[y]+2∑(dep[lca])
o
p
t
=
2
:
opt=2:
opt=2:
对于这个操作,我们需要考虑正负,所以我们需要储存一下
m
i
n
(
0
,
f
(
y
)
)
min(0,f(y))
min(0,f(y)).
所以我们开一个数组存一下每次y的消除结果:
f
f
y
+
=
m
i
n
(
0
,
f
(
y
)
)
−
f
(
y
)
ff_y+=min(0,f(y))-f(y)
ffy+=min(0,f(y))−f(y)
之后我们算答案只要加上
f
f
y
ff_y
ffy就行了:
f
(
y
)
=
A
−
B
∗
d
e
p
[
y
]
+
2
∑
(
d
e
p
[
l
c
a
]
)
+
f
f
y
f(y)=A-B*dep[y]+2\sum(dep[lca])+ff_y
f(y)=A−B∗dep[y]+2∑(dep[lca])+ffy
o
p
t
=
=
3
opt==3
opt==3
输出结果即可
A C AC AC C o d e Code Code:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=1e5+5;
int dep[MAXN],siz[MAXN],son[MAXN],fa[MAXN],top[MAXN];
struct node{int l,r,lz;ll sum;}tr[MAXN<<2];
int ld[MAXN],id[MAXN],ff[MAXN];
vector<int> vec[MAXN<<2];
int t,n,m,tot,op;
ll A,B;
void pre(int pos){
siz[pos]=1;
son[pos]=0;
for(int i=0,v;i<vec[pos].size();++i){
v=vec[pos][i];
if(v==fa[pos]) continue;
fa[v]=pos;
dep[v]=dep[pos]+1;
pre(v);
siz[pos]+=siz[v];
if(siz[v]>siz[son[pos]])
son[pos]=v;
}
}
void dfs(int x,int y){
top[x]=y;
ld[x]=++tot;
id[tot]=x;
if(son[x]) dfs(son[x],y);
for(int i=0,v;i<vec[x].size();++i){
v=vec[x][i];
if(v^son[x]&&v^fa[x])
dfs(v,v);
}
}
void build(int pos,int l,int r){
tr[pos].l=l;
tr[pos].r=r;
tr[pos].lz=0;
if(l==r){
tr[pos].sum=0;
return;
}
int mid=l+r>>1;
build(pos<<1,l,mid);
build(pos<<1|1,mid+1,r);
}
void my(int pos,int l,int r,ll val){
if(tr[pos].l==l&&tr[pos].r==r){
tr[pos].lz+=val;
return;
}
tr[pos].sum+=(ll)(r-l+1)*val;
int mid=tr[pos].l+tr[pos].r>>1;
if(r<=mid) my(pos<<1,l,r,val);
else if(l>mid) my(pos<<1|1,l,r,val);
else my(pos<<1,l,mid,val),my(pos<<1|1,mid+1,r,val);
}
void cy(int x,int y,ll val){
while(top[x]^top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
my(1,ld[top[y]],ld[y],val);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
my(1,ld[x],ld[y],val);
}
void down(int pos){
tr[pos].sum+=tr[pos].lz*(tr[pos].r-tr[pos].l+1);
tr[pos<<1].lz+=tr[pos].lz;
tr[pos<<1|1].lz+=tr[pos].lz;
tr[pos].lz=0;
}
ll q(int pos,int l,int r){
if(tr[pos].l==l&&tr[pos].r==r)
return (ll)tr[pos].sum+tr[pos].lz*(r-l+1);
if(tr[pos].lz) down(pos);
int mid=tr[pos].l+tr[pos].r>>1;
if(r<=mid) return q(pos<<1,l,r);
else if(l>mid) return q(pos<<1|1,l,r);
else return q(pos<<1,l,mid)+q(pos<<1|1,mid+1,r);
}
ll q1(int x,int y){
ll ret=A-B*dep[x]+ff[x];
while(top[x]^top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
ret+=q(1,ld[top[y]],ld[y]);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
return ret+q(1,ld[x],ld[y]);
}
int main(){
scanf("%d",&t);
while(t--){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i)
vec[i].clear();
for(int i=1,x,y;i<n;++i){
scanf("%d%d",&x,&y);
vec[x].push_back(y);
vec[y].push_back(x);
}
memset(tr,0,sizeof(tr));
memset(dep,0,sizeof(dep));
memset(siz,0,sizeof(siz));
memset(son,0,sizeof(son));
memset(fa,0,sizeof(fa));
memset(ff,0,sizeof(ff));
memset(ld,0,sizeof(ld));
memset(top,0,sizeof(top));
memset(id,0,sizeof(id));
A=B=tot=0;
dep[1]=1;
pre(1);
dfs(1,1);
build(1,1,tot);
while(m--){
int x,y;
scanf("%d",&op);
if(op==1){
scanf("%d%d",&x,&y);
cy(1,x,2ll);
A+=y-dep[x];
B++;
}
else{
scanf("%d",&x);
ll val=q1(x,1);
if(op==2) ff[x]+=min(0ll,val)-val;
if(op==3) printf("%lld\n",val);
}
}
}
}