转自:https://www.cnblogs.com/zjp-shadow/p/9773420.html
这个代码太6了
#include <bits/stdc++.h>
#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
typedef long long ll;
template<typename T> inline bool chkmin(T &a, T b) {return b < a ? a = b, 1 : 0;}
template<typename T> inline bool chkmax(T &a, T b) {return b > a ? a = b, 1 : 0;}
inline int read() {
int x(0), sgn(1); char ch(getchar());
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
return x * sgn;
}
void File() {
#ifdef zjp_shadow
freopen ("C.in", "r", stdin);
freopen ("C.out", "w", stdout);
#endif
}
const int N = 1e6 + 1e3;
bitset<N> pass;
ll sum[N], dp[N]; int n, fa[N];
int main () {
File();
n = read();
For (i, 1, n) sum[i] = read();
For (i, 2, n) fa[i] = read();
Fordown (i, n, 1) sum[fa[i]] += sum[i];
// tmp == k, 最多可以分成k组
// dp保存符合 sum[1] / __gcd(sum[1], sum[i]) = k 的子树i的总数
For (i, 1, n) {
ll tmp = sum[1] / __gcd(sum[1], sum[i]);
if (tmp <= n) ++ dp[tmp];
}
// dp[k]保存符合 sum[1] / __gcd(sum[1], sum[i]) = n * k 的子树i的总数
Fordown (i, n, 1) if (dp[i])
for (int j = i * 2; j <= n; j += i) dp[j] += dp[i];
// pass[i] 表示子树是否存在分成 i 个联通块的方案
For (i, 1, n)
pass[i] = (dp[i] == i && !(sum[1] % i)), dp[i] = 0;
dp[1] = pass[1];
// dp[j]保存有多少种方法可以从i个联通块,只分一级,分成j个联通块
// 因为分出来的新的一层所有联通块和要一样,所以j必须是i的倍数,画图可得
ll ans = 0;
For (i, 1, n) if (pass[i]) {
for (int j = i * 2; j <= n; j += i)
if (pass[j]) dp[j] += dp[i];
ans += dp[i];
}
printf ("%lld\n", ans);
return 0;
}