ARC121E Directed Tree
一个点如果想要合法,首先不能和子树中的点重复,其次不能填子树中的点的编号 —— 这会有后效性,使得我们非常困难地处理 DP
转移。但是如果我们考虑不合法,十分简单 —— 只需要填入子树中的其中一个,并且不和子树中不合法的点填同样的数字即可。
正难则反,看成有多个限制形如 i i i 不能在 i i i 的子树除自己中出现,考虑容斥,记 F ( i ) F(i) F(i) 表示有 i i i 个不合法点,其他不确定的方案数,答案即为 ∑ i = 0 n ( − 1 ) i × F ( i ) \sum\limits_{i=0}^{n}(-1)^i\times F(i) i=0∑n(−1)i×F(i)。
考虑 DP
,设
f
u
,
i
f_{u,i}
fu,i 表示
u
u
u 的子树内有
i
i
i 个不合法点的情况,转移时先统计儿子的方案数。
f
u
,
a
+
b
←
f
u
,
a
+
b
+
f
u
,
a
′
×
f
v
,
b
f_{u,a+b}\leftarrow f_{u,a+b}+f'_{u,a}\times f_{v,b}
fu,a+b←fu,a+b+fu,a′×fv,b
再考虑转移
u
u
u 点的方案数,
u
u
u 能填的不合法位置有
s
i
z
u
−
1
siz_u-1
sizu−1 个,目前已经填了
i
−
1
i-1
i−1 个,有
s
i
z
u
−
i
siz_u-i
sizu−i 个能填的位置,
i
i
i 从大到小更新不重复。
f
u
,
i
←
f
u
,
i
+
(
s
i
z
u
−
i
)
f
u
,
i
−
1
f_{u,i}\leftarrow f_{u,i}+(siz_u-i)f_{u,i-1}
fu,i←fu,i+(sizu−i)fu,i−1
最后答案显然是 ∑ i = 0 n ( − 1 ) i × f 1 , i × ( n − i ) ! \sum\limits_{i=0}^{n}(-1)^i \times f_{1,i} \times (n-i)! i=0∑n(−1)i×f1,i×(n−i)!,其中 f 1 , 0 = 1 f_{1,0}=1 f1,0=1。
时间复杂度 O ( n 2 ) \mathcal O(n^2) O(n2)。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define he putchar('\n')
#define ha putchar(' ')
typedef long long ll;
inline int read()
{
int x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9')
{
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9')
x = x * 10 + c - 48, c = getchar();
return x * f;
}
inline void write(int x)
{
if(x < 0)
{
x = -x;
putchar('-');
}
if(x > 9) write(x / 10);
putchar(x % 10 + 48);
}
const int _ = 2010, mod = 998244353;
int n, ans, fac[_], f[_][_], siz[_], t[_];
vector<int> d[_];
void dfs(int u, int fa)
{
f[u][0] = 1;
for(int v : d[u])
{
if(v == fa) continue;
dfs(v, u);
for(int j = 0; j <= siz[u]; ++j)
for(int k = 0; k <= siz[v]; ++k)
t[j + k] = (t[j + k] + f[u][j] * f[v][k] % mod) % mod;
siz[u] += siz[v];
for(int j = 0; j <= siz[u]; ++j) f[u][j] = t[j], t[j] = 0;
}
++siz[u];
for(int i = siz[u] - 1; i >= 0; --i)
f[u][i + 1] = (f[u][i + 1] + (siz[u] - i - 1) * f[u][i] % mod) % mod;
}
signed main()
{
n = read();
for(int i = 2, x; i <= n; ++i)
{
x = read();
d[i].push_back(x), d[x].push_back(i);
}
dfs(1, 0);
fac[0] = 1;
for(int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
for(int i = 0; i <= n; ++i) ans = (ans + (i & 1 ? -1 : 1) * f[1][i] * fac[n - i] % mod + mod) % mod;//, cout << ans << "!!!\n";
write((ans % mod + mod) % mod), he;
return 0;
}