Atcoder传送门
题目大意
给你一棵有 N N 个节点的有根树, 根节点为。给每个节点一个权值 Vi V i 为 0 0 或。现在要求将树的标号对应在一个序列 a1,a2,...,aN a 1 , a 2 , . . . , a N 上,并且 ai a i 的右侧没有 i i 的祖先节点。 求这样的序列中逆序对最少的个数。
输入输出格式
输入格式
第一行一个整数。
第二行 N−1 N − 1 个整数, 分别代表 fat[2],fat[3],fat[4],...,fat[N] f a t [ 2 ] , f a t [ 3 ] , f a t [ 4 ] , . . . , f a t [ N ] 。
第三行 N N 个整数, 代表。
输出格式
一行一个整数表示逆序对最少的个数。
输入输出样例
输入样例#1
6
1 1 2 3 3
0 1 1 0 0 0
输出样例#1
4
输入样例#2
1
0
输出样例#2
0
输入样例#3
15
1 2 3 2 5 6 2 2 9 10 1 12 13 12
1 1 1 0 1 1 0 0 1 0 0 1 1 0 0
输出样例#3
31
数据范围
- 1≤N≤2×105 1 ≤ N ≤ 2 × 10 5
- 1≤fat[i]<i 1 ≤ f a t [ i ] < i (2≤i≤N) ( 2 ≤ i ≤ N )
- 0≤Vi≤1(1≤i≤N) 0 ≤ V i ≤ 1 ( 1 ≤ i ≤ N )
解题分析
此题真是神题…
考虑一开始每个节点为单独的一个点, 我们将它们合并的过程。
对于一个节点 A A 及其合并得到的一些点构成的集合,记 f(SA) f ( S A ) 为其中0的个数, g(SA) g ( S A ) 为其中1的个数。
我们可以发现, 每次合并 f(SA)g(SA) f ( S A ) g ( S A ) 值最大的点到 fat[A] f a t [ A ] 中最优。
具体而言, 我们先不考虑 SA S A 和 Sfat[A] S f a t [ A ] 内的贡献(我们在合并它们内部的点的时候已经考虑过了), 它们之间对答案的贡献是 f(Sfat[A])×g(SA) f ( S f a t [ A ] ) × g ( S A ) 。如果 fat[A] f a t [ A ] 有另一个儿子 B B , 而, 那么对于 fat[A] f a t [ A ] 而言它们的排列顺序不影响逆序对个数。 如果排列顺序为 Sfat[A],SA,SB S f a t [ A ] , S A , S B , SA S A 与 SB S B 间的贡献为 f(SB)×g(SA)<g(SB)×f(SA) f ( S B ) × g ( S A ) < g ( S B ) × f ( S A ) ,所以我们应该将 f(SA)g(SA) f ( S A ) g ( S A ) 值最大的点优先合并到 fat[A] f a t [ A ] 中。
剩下的就是用 priority queue p r i o r i t y q u e u e 或者 set s e t 维护这个值, 用并查集维护合并状态即可。 注意合并的时候要将 fat[A] f a t [ A ] 也先 erase e r a s e 掉, 因为合并 f(SA) f ( S A ) 和 g(SA) g ( S A ) 的时候 set s e t 中 Sfat[A] S f a t [ A ] 的位置是没有变的。
代码如下:
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cctype>
#include <set>
#define R register
#define IN inline
#define W while
#define IN inline
#define gc getchar()
#define MX 200500
template <class T>
IN void in(T &x)
{
x = 0; R char c = gc;
W (!isdigit(c)) c = gc;
W (isdigit(c))
x = (x << 1) + (x << 3) + c - 48, c = gc;
}
int col[MX][2], fat[MX], bel[MX], dot;
struct Vertex
{int id;};
IN bool operator < (const Vertex &x, const Vertex &y)
{
int x0 = col[x.id][0], x1 = col[x.id][1], y0 = col[y.id][0], y1 = col[y.id][1];
return 1ll * x0 * y1 > 1ll * x1 * y0 ||
(1ll * x0 * y1 == 1ll * x1 * y0 && x.id > y.id);
}
std::set <Vertex> st;
long long ans;
namespace DSU
{
int find(R int now) {return bel[now] == now ? now : bel[now] = find(bel[now]);}
IN void combine(const int &from, const int &to) {bel[from] = find(to); col[to][0] += col[from][0], col[to][1] += col[from][1];}
}
int main(void)
{
in(dot); R int a, now, fa;
for (R int i = 2; i <= dot; ++i) in(fat[i]);
for (R int i = 1; i <= dot; ++i) in(a), ++col[i][a], bel[i] = i;
for (R int i = 2; i <= dot; ++i) st.insert({i});
W (!st.empty())
{
now = DSU::find(st.begin() -> id);
st.erase(st.begin());
fa = DSU::find(fat[now]); st.erase({fa});
ans += 1ll * col[fa][1] * col[now][0];
DSU::combine(now, fa);
if(fa != 1) st.insert({fa});
}
printf("%lld", ans);
}