题目描述
给定一个 n n n个节点的的有根树,编号依次为 1 1 1到 n n n,其中 1 1 1号节点为根节点。每个点有一个权值 v i v_i vi。
你需要将这棵树转化为一个大根堆。确切地说,你需要找到尽可能多的节点,满足大根堆的性质:对于任意两个点 i , j i,j i,j,如果 i i i在树上是 j j j的祖先,那么 v i > v j v_i>v_j vi>vj。
请计算可选的最多的点数,注意这些点不必形成这棵树的一个连通子树。
输入样例
6
3 0
1 1
2 1
3 1
4 1
5 1
输出样例
5
数据范围
- 1 ≤ n ≤ 2 × 1 0 5 1\leq n\leq 2\times 10^5 1≤n≤2×105
- 0 ≤ v i ≤ 1 0 9 , 1 ≤ p i < i 0\leq v_i\leq 10^9,1\leq p_i< i 0≤vi≤109,1≤pi<i
题解
前置知识:线段树合并
这道题其实就是让我们求树上的最长上升子序列。
首先,我们先考虑一个序列的最长上升子序列。假设已经求出了前 i i i个数的最长上升子序列,在加入第 i + 1 i+1 i+1个数时,在当前数列中找到第一个大于等于它的数。如果有,则用新加入的数替换;否则将新加入的数放在队尾。
树上呢?也一样。用权值线段树维护每一个点,线段树的 s i z siz siz值就是该点所在子树中能选的最多的点的数量。
对于每个节点,先遍历其子树,然后将该节点的儿子节点的线段树合并到自己的线段树上。最后,在自己的线段树上找到第一个大于等于该节点的权值的位置。如果有,在权值线段树上这个位置减一,在该节点的权值所在位置加一;否则直接在该节点权值所在的位置加一。
最后,根节点的线段树的大小就是答案。时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
code
#include<bits/stdc++.h>
#define N 200000
using namespace std;
int n,x,tot=0,f,v[200005],num[200005],d[400005],l[400005],r[400005],rt[200005];
struct node{
int lc,rc,s;
}tr[10000005];
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void pt(int &k,int l,int r,int z){
if(!k) k=++tot;
if(l==r){
++tr[k].s;return;
}
int mid=(l+r)/2;
if(z<=mid) pt(tr[k].lc,l,mid,z);
else pt(tr[k].rc,mid+1,r,z);
tr[k].s=tr[tr[k].lc].s+tr[tr[k].rc].s;
}
void merge(int &r1,int r2,int l,int r){
if(!r1||!r2){
r1=r1+r2;return;
}
if(l==r){
tr[r1].s+=tr[r2].s;return;
}
int mid=(l+r)/2;
merge(tr[r1].lc,tr[r2].lc,l,mid);
merge(tr[r1].rc,tr[r2].rc,mid+1,r);
tr[r1].s=tr[tr[r1].lc].s+tr[tr[r1].rc].s;
}
void dele(int k,int l,int r,int z){
if(!k) return;
if(l==r){
if(tr[k].s>0){
--tr[k].s;f=1;
}
return;
}
int mid=(l+r)/2;
if(z<=mid){
dele(tr[k].lc,l,mid,z);
if(!f) dele(tr[k].rc,mid+1,r,z);
}
else dele(tr[k].rc,mid+1,r,z);
tr[k].s=tr[tr[k].lc].s+tr[tr[k].rc].s;
}
void dfs(int u){
for(int i=r[u];i;i=l[i]){
dfs(d[i]);
merge(rt[u],rt[d[i]],1,N);
}
f=0;
dele(rt[u],1,N,v[u]);
pt(rt[u],1,N,v[u]);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d",&v[i],&x);num[i]=v[i];
if(x) add(x,i);
}
tot=0;
sort(num+1,num+n+1);
int gs=unique(num+1,num+n+1)-num-1;
for(int i=1;i<=n;i++){
v[i]=lower_bound(num+1,num+gs+1,v[i])-num;
}
dfs(1);
printf("%d",tr[rt[1]].s);
return 0;
}
507





