题目大意
给定数列
a
(
0
≤
a
i
≤
n
)
a(0\le a_i\le n)
a(0≤ai≤n),设其排列后的数列为
a
p
1
,
a
p
2
,
.
.
.
,
a
p
n
a_{p_1},a_{p_2},...,a_{p_n}
ap1,ap2,...,apn,要求对于任意的
k
k
k,满足
1
≤
j
≤
k
,
a
p
j
≠
p
k
1\le j\le k,a_{p_j}\neq p_k
1≤j≤k,apj̸=pk。某个合法排列的价值为
w
p
1
+
2
w
p
2
+
.
.
.
+
n
w
p
n
w_{p_1}+2w_{p_2}+...+nw_{p_n}
wp1+2wp2+...+nwpn,求最大价值。
n
≤
5
×
1
0
5
n\le 5\times 10^5
n≤5×105
题解
神题qaq。
首先,好好思考题目中那个乱七八糟的条件是啥,也就是说让
=
p
k
=p_k
=pk的
a
a
a值出现在第
k
k
k个之后。我们考虑最后
p
p
p的排列方式,显然对于任意
a
i
a_i
ai,使
=
a
i
=a_i
=ai的
p
p
p值出现在
=
i
=i
=i的
p
p
p值之前。
于是我们就得到了若干个限制
p
p
p数列的关系,于是对于每个
a
i
a_i
ai,从
a
i
a_i
ai向
i
i
i连一条边,得到了一个图。如果这个图中存在环的话显然无解,否则这个图就是以0为根的树。
考虑
w
w
w最小的那个点,显然如果它的father被选了,下一个选的必然是它。因此它俩必然在数列中连续,于是把它们用并查集合起来,计算贡献。
然而这样操作之后就变成了多个连通块比较,我们如何选择最小的连通块呢?
考虑任意两个连通块
a
,
b
a,b
a,b,如果
a
a
a在
b
b
b之前更优,则必然有
s
i
z
e
[
a
]
⋅
s
u
m
[
b
]
+
s
u
m
[
a
]
+
s
u
m
[
b
]
>
s
u
m
[
a
]
⋅
s
i
z
e
[
b
]
+
s
u
m
[
a
]
+
s
u
m
[
b
]
size[a]\cdot sum[b]+sum[a]+sum[b]>sum[a]\cdot size[b]+sum[a]+sum[b]
size[a]⋅sum[b]+sum[a]+sum[b]>sum[a]⋅size[b]+sum[a]+sum[b],消一下就变成了
s
i
z
e
[
a
]
⋅
s
u
m
[
b
]
>
s
i
z
e
[
b
]
⋅
s
u
m
[
a
]
size[a]\cdot sum[b]>size[b]\cdot sum[a]
size[a]⋅sum[b]>size[b]⋅sum[a]。
于是用优先队列维护最小值就行了,复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)。
#include <bits/stdc++.h>
namespace IOStream {
const int MAXR = 10000000;
char _READ_[MAXR], _PRINT_[MAXR];
int _READ_POS_, _PRINT_POS_, _READ_LEN_;
inline char readc() {
#ifndef ONLINE_JUDGE
return getchar();
#endif
if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
char c = _READ_[_READ_POS_++];
if (_READ_POS_ == MAXR) _READ_POS_ = 0;
if (_READ_POS_ > _READ_LEN_) return 0;
return c;
}
template<typename T> inline void read(T &x) {
x = 0; register int flag = 1, c;
while (((c = readc()) < '0' || c > '9') && c != '-');
if (c == '-') flag = -1; else x = c - '0';
while ((c = readc()) >= '0' && c <= '9') x = x * 10 - '0' + c;
x *= flag;
}
template<typename T1, typename ...T2> inline void read(T1 &a, T2&... x) {
read(a), read(x...);
}
inline int reads(char *s) {
register int len = 0, c;
while (isspace(c = readc()) || !c);
s[len++] = c;
while (!isspace(c = readc()) && c) s[len++] = c;
s[len] = 0;
return len;
}
inline void ioflush() { fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0; fflush(stdout); }
inline void printc(char c) {
if (!c) return;
_PRINT_[_PRINT_POS_++] = c;
if (_PRINT_POS_ == MAXR) ioflush();
}
inline void prints(const char *s, char c) {
for (int i = 0; s[i]; i++) printc(s[i]);
printc(c);
}
template<typename T> inline void print(T x, char c = '\n') {
if (x < 0) printc('-'), x = -x;
if (x) {
static char sta[20];
register int tp = 0;
for (; x; x /= 10) sta[tp++] = x % 10 + '0';
while (tp > 0) printc(sta[--tp]);
} else printc('0');
printc(c);
}
template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
print(x, ' '), print(y...);
}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
const int MAXN = 500005;
struct Node {
ll sum; int sz, rt;
bool operator<(const Node &nd) const { return sum * nd.sz > sz * nd.sum; }
};
priority_queue<Node> pq;
int ww[MAXN], par[MAXN], sz[MAXN], fa[MAXN], n;
ll sum[MAXN];
int find(int x) { return x == par[x] ? x : par[x] = find(par[x]); }
void merge(int x, int y) {
x = find(x), y = find(y);
if (x == y) return;
par[x] = y, sz[y] += sz[x], sum[y] += sum[x];
}
int main() {
read(n);
for (int i = 0; i <= n; i++) par[i] = i;
for (int i = 1; i <= n; i++) {
int t; read(t); fa[i] = t;
if (find(i) == find(t)) return puts("-1") * 0;
par[find(i)] = find(t);
}
ll res = 0;
for (int i = 1; i <= n; i++) {
par[i] = i, sz[i] = 1; read(sum[i]);
pq.push((Node) { sum[i], 1, i });
}
par[0] = 0, sz[0] = 1;
for (int i = 1; i <= n; i++) {
Node d = pq.top(); pq.pop();
if (!d.rt || sz[d.rt] != d.sz) { --i; continue; }
int p = find(fa[d.rt]);
res += sum[d.rt] * sz[p];
merge(d.rt, p);
pq.push((Node) { sum[p], sz[p], p });
}
printf("%lld\n", res);
return 0;
}