题目链接
http://acm.zzuli.edu.cn/problem.php?id=2520
题意
有一棵有根树,根结点编号为
1
1
1,编号为
i
i
i 的结点的权值为
w
i
w_i
wi,现在定义结点
(
u
,
v
)
(u,v)
(u,v) 为“大小接近的点对”,当且仅当满足:
(1)
u
u
u 是
v
v
v 的祖先结点(
v
v
v 可以等于
u
u
u)
(2)
∣
w
u
−
w
v
∣
≤
k
|w_u-w_v|\leq k
∣wu−wv∣≤k
现在对于树上每个结点
i
i
i,需要计算以其为根的子树中“大小接近的点对”的个数。
思路
首先不难得出,当以
u
u
u 为根时,设
v
v
v 为其子树的某一结点,则需要满足的是:
∣
w
u
−
w
v
∣
≤
k
−
k
≤
w
v
−
w
u
≤
k
w
u
−
k
≤
w
v
≤
w
u
+
k
|w_u-w_v|\leq k\\ -k \leq w_v-w_u\leq k \\ w_u-k \leq w_v \leq w_u+k
∣wu−wv∣≤k−k≤wv−wu≤kwu−k≤wv≤wu+k
所以当结点
v
v
v 的权值满足在这个范围内时,就说明
(
u
,
v
)
(u,v)
(u,v) 是“大小接近的点对”。该如何统计呢?在一个子树中,根是确定的,所以,想到 dfs 序,一棵子树在 dfs 序中是连续的,那问题就转化为:在树的
d
f
s
dfs
dfs 序列中,找很多区间值在
[
w
u
−
k
,
w
u
+
k
]
[w_u-k,w_u+k]
[wu−k,wu+k] 范围内的数的个数。
我们不妨用主席树尝试解决这一问题,由于主席树一般都是权值线段树的可持久化,所以题目中
w
i
w_i
wi 的范围和
k
k
k 的范围都较大,显然不行,所以需要先离散化,所以首先第一步就是,把每个结点的
w
i
w_i
wi ,
w
i
−
k
w_i-k
wi−k ,
w
i
+
k
w_i+k
wi+k 都放在一起离散化一下,
w
i
−
k
w_i-k
wi−k 如果为负值就略过。然后根据时间顺序建立可持久化权值线段树,也就是按照时间顺序,每个时刻往主席树里插入一个 dfs 序的当前时间的结点权值。之后针对每个结点
i
i
i,根据主席树的前缀可减性,查询在
n
n
n 个历史版本中时间段在
[
i
n
i
−
1
,
i
n
i
+
s
i
z
i
−
1
]
[in_i-1,in_i+siz_i-1]
[ini−1,ini+sizi−1] 内满足 大小在离散化后
[
w
i
−
k
,
w
i
+
k
]
[w_i-k,w_i+k]
[wi−k,wi+k] 的数的个数,现在统计出的以每个结点为根,子树中与其“大小接近的点对”,所以还需要 dfs 一下,自下而上累加求和,才能得出每个结点为根的子树中所有“大小接近的点对”。
代码
#include <bits/stdc++.h>
#define ft first
#define sd second
#define pb push_back
#define endl '\n'
#define nul string::npos
#define sz(ss) (int)ss.size()
using namespace std;
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
const int maxn=1e5+7;
const int inf=0x3f3f3f3f;
const int mod=998244353;
int n,k;
vector<int> g[maxn];
int w[maxn];
int in[maxn],siz[maxn],t,root[maxn],cnt;
int a[maxn*3],b[maxn*3];
int len,szz;
int x[maxn];
ll dp[maxn];
struct node{
int l,r,sum;
}no[maxn*40];
void dfs(int u){
in[u]=++t;//u在dfs序中的时刻编号
x[t]=u;//dfs序中t时刻的结点为u
siz[u]=1;//u的子树大小
for(int i=0;i<sz(g[u]);i++){
int v=g[u][i];
dfs(v);
siz[u]+=siz[v];
}
}
void _dfs(int u){
for(int i=0;i<sz(g[u]);i++){
int v=g[u][i];
_dfs(v);
dp[u]+=dp[v];
}
}
void lis(){//离散化
for(int i=1;i<=n;i++){
if(w[i]-k>=1){
a[len]=w[i]-k;
b[len]=a[len];
len++;
}
a[len]=w[i]+k;
b[len]=a[len];
len++;
}
sort(b+1,b+len);
szz=unique(b+1,b+len)-(b+1);
for(int i=1;i<len;i++){
a[i]=lower_bound(b+1,b+1+szz,a[i])-b;
}
}
void update(int l,int r,int pre,int &now,int x){
now=++cnt;no[now]=no[pre];
no[now].sum++;
if(l==r)return;
int m=(l+r)>>1;
if(x<=m)update(l,m,no[pre].l,no[now].l,x);
else update(m+1,r,no[pre].r,no[now].r,x);
}
int query(int l,int r,int pre,int now,int L,int R){
if(L<=l&&r<=R)return no[now].sum-no[pre].sum;
int m=(l+r)>>1,res=0;
if(L<=m)res=query(l,m,no[pre].l,no[now].l,L,R);
if(R>m)res+=query(m+1,r,no[pre].r,no[now].r,L,R);
return res;
}
int main(){
cin>>n>>k;
for(int i=1;i<=n;i++){
scanf("%d",&w[i]);
a[i]=w[i];
b[i]=w[i];
}
for(int i=1,u;i<n;i++){
scanf("%d",&u);
g[u].pb(i+1);
}
dfs(1);
len=n+1;
lis();
for(int i=1;i<=n;i++){
update(1,szz,root[i-1],root[i],a[x[i]]);
}
len=n+1;
for(int i=1;i<=n;i++){
if(w[i]-k>=1){
dp[i]=query(1,szz,root[in[i]-1],root[in[i]+siz[i]-1],a[len],a[len+1]);
len+=2;
}
else {
dp[i]=query(1,szz,root[in[i]-1],root[in[i]+siz[i]-1],1,a[len++]);
}
}
_dfs(1);
//会爆int
for(int i=1;i<=n;i++)printf("%lld\n",dp[i]);
return 0;
}