题目描述
给定一个 n n n个节点的的有根树,编号依次为 1 1 1到 n n n,其中 1 1 1号节点为根节点。每个点有一个权值 v i v_i vi。
你需要将这棵树转化为一个大根堆。确切地说,你需要找到尽可能多的节点,满足大根堆的性质:对于任意两个点 i , j i,j i,j,如果 i i i在树上是 j j j的祖先,那么 v i > v j v_i>v_j vi>vj。
请计算可选的最多的点数,注意这些点不必形成这棵树的一个连通子树。
输入样例
6
3 0
1 1
2 1
3 1
4 1
5 1
输出样例
5
数据范围
- 1 ≤ n ≤ 2 × 1 0 5 1\leq n\leq 2\times 10^5 1≤n≤2×105
- 0 ≤ v i ≤ 1 0 9 , 1 ≤ p i < i 0\leq v_i\leq 10^9,1\leq p_i< i 0≤vi≤109,1≤pi<i
题解
之前写过一篇类似的:BZOJ4919 大根堆(线段树合并)
不过下面这种方法不用线段树合并,只需要用multiset。
首先,我们先考虑一个序列的最长上升子序列。假设已经求出了前 i i i个数的最长上升子序列,在加入第 i + 1 i+1 i+1个数时,在当前数列中找到第一个大于等于它的数。如果有,则用新加入的数替换;否则将新加入的数放在队尾。
树上呢?也一样。对于每个节点,先遍历其子树,然后将该节点的儿子节点的最长上升子序列合并到自己的序列上。最后,在自己的序列上找到第一个大于等于该节点的权值的位置。如果有,将其在序列上删去,再把该节点的权值加在序列中;否则直接把该节点权值加在序列中。
这样做的话,每个节点都最多需要 O ( n ) O(n) O(n)的时间复杂度来合并,看起来是 O ( n 2 ) O(n^2) O(n2)的,但是我们可以用一种特殊的方法来保证其时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
我们将每个点的各个儿子中子树的节点数量最多的儿子称为重儿子,其余为轻儿子。重儿子与父亲的连边称为重边,其余边称为轻边,重边连成的链称为重链。那么每次先遍历重儿子,再将当前节点的multiset序列和其重儿子的multiset序列互换。也就是说,重链上的点在重链中只会被放入序列一次。在这种情况下,我们考虑每个点被放入序列了多少次。每个点到根节点的路径上最多只会有 log n \log n logn条轻边(每从轻儿子沿轻边向上,子树大小至少为原来的两倍),那么,每个点总共只会被加 O ( log n ) O(\log n) O(logn)次,所有节点总共最多会被加入序列 O ( n log n ) O(n\log n) O(nlogn)次。而处理一条重边的时间复杂度为 O ( 1 ) O(1) O(1),所以处理所有重边的总时间复杂度为 O ( n ) O(n) O(n)。
最后,根节点的序列的长度就是答案。因为可能有权值相等的节点,所以要用multiset而不能用set。因为用了multiset,所以时间复杂度为 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
code
#include<bits/stdc++.h>
using namespace std;
int n,tot=0,a[200005],d[200005],l[200005],r[200005],siz[200005],son[200005];
multiset<int>s[200005];
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void dfs1(int u){
siz[u]=1;
for(int i=r[u];i;i=l[i]){
dfs1(d[i]);
siz[u]+=siz[d[i]];
if(siz[d[i]]>siz[son[u]]) son[u]=d[i];
}
}
void dfs2(int u){
multiset<int>::iterator it;
if(son[u]){
dfs2(son[u]);
swap(s[u],s[son[u]]);
}
for(int i=r[u];i;i=l[i]){
if(d[i]==son[u]) continue;
dfs2(d[i]);
for(it=s[d[i]].begin();it!=s[d[i]].end();++it){
s[u].insert(*it);
}
s[d[i]].clear();
}
it=s[u].lower_bound(a[u]);
if(it!=s[u].end()) s[u].erase(it);
s[u].insert(a[u]);
}
int main()
{
scanf("%d",&n);
for(int i=1,x;i<=n;i++){
scanf("%d%d",&a[i],&x);add(x,i);
}
dfs1(1);
dfs2(1);
printf("%d",s[1].size());
return 0;
}