思路来自:https://blog.youkuaiyun.com/jk_chen_acmer/article/details/99092788
题意:
长度为L的路,每一步至少走d步,求从0恰好走到L的方案数。限制条件,有m次attack,每一次attack会在ti的时间攻击pi,也就是第ti步不能走到pi位置。
思路:
先考虑没有限制条件的情况。令
f
[
i
]
f[i]
f[i]表示到第i个位置的方案数。则
f
[
i
]
=
f
[
i
−
d
]
+
f
[
i
−
d
−
1
]
+
.
.
.
+
f
[
0
]
f[i] = f[i-d]+f[i-d-1]+...+f[0]
f[i]=f[i−d]+f[i−d−1]+...+f[0]
使用前缀和能够做到
O
(
L
)
O(L)
O(L)复杂度内实现。
接下来考虑有限制条件的情况。即是在
f
[
L
]
f[L]
f[L]的前提下去除不合法的方案数。
将所有attack按
p
[
i
]
p[i]
p[i]排序。所有不合法的即是,对于每一个
i
i
i,
[
0
,
p
i
)
[0,p_i)
[0,pi)区间内不经过attack的方案数
∗
[
p
i
,
L
]
*[p_i,L]
∗[pi,L]区间内的所有方案数,再将所有的累加。按顺序枚举了是否包含第i个的所有情况。这样能够保证不重不漏枚举所有非法情况。
具体求法是:
令
g
[
i
]
g[i]
g[i]表示到第
i
i
i个attack,前
[
0
,
p
i
)
[0,p_i)
[0,pi)内不经过attack,到
p
i
p_i
pi时恰好经过attack的方案数。计
c
a
l
c
(
a
,
b
)
calc(a,b)
calc(a,b)表示距离为
a
a
a,经过
b
b
b步的所有可能的方案数。
g
[
i
]
=
c
a
l
c
(
p
i
,
t
i
)
∗
f
[
L
−
p
i
]
−
∑
j
=
1
i
−
1
g
[
j
]
∗
c
a
l
c
(
p
i
−
p
j
,
t
i
−
t
j
)
g[i] = calc(p_i,t_i)*f[L-p_i] - \sum_{j=1}^{i-1}g[j]*calc(p_i-p_j,t_i-t_j)
g[i]=calc(pi,ti)∗f[L−pi]−j=1∑i−1g[j]∗calc(pi−pj,ti−tj)
最终的方案数就是
f
[
L
]
−
∑
i
=
1
m
g
[
i
]
∗
f
[
L
−
p
i
]
f[L]-\sum_{i=1}^{m}g[i]*f[L-p_i]
f[L]−i=1∑mg[i]∗f[L−pi]
c
a
l
c
(
a
,
b
)
calc(a,b)
calc(a,b)函数的求法。可以理解为
a
a
a个相同的小球,放入
b
b
b个不相同的盒子里,每一个盒子至少放
d
d
d个。考虑“隔板法”,我们不妨先往
b
b
b个盒子里各放入
d
−
1
d-1
d−1个小球,转化成了
a
−
(
d
−
1
)
∗
b
a-(d-1)*b
a−(d−1)∗b个相同的小球,放入
b
b
b个盒子里,不允许空盒的情况。所以最终答案即是
C
a
−
(
d
−
1
)
∗
b
−
1
b
−
1
C_{a-(d-1)*b-1}^{b-1}
Ca−(d−1)∗b−1b−1
代码:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ld long double
#define ull unsigned long long
#define __ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
const int maxn = 1e7 + 10;
const ll mod = 998244353;
const ld PI = acos(-1.0);
ll f[maxn], s[maxn], jie[maxn];
ll L, d, m;
ll g[3030];
struct pxy {
int t, p;
bool operator<(const pxy &b) const {
return p < b.p;
}
} a[3030];
ll pow_mod(ll a, ll b, ll m) {
ll ans = 1;
while (b) {
if (b & 1)ans = ans * a % m;
a = a * a % m;
b >>= 1;
}
return ans;
}
inline ll ni(ll x) {
return pow_mod(jie[x], mod - 2, mod);
}
inline ll C(ll n, ll m) {
if (m > n) return 0;
return ((jie[n] * ni(m)) % mod) * ni(n - m) % mod;
}
ll calc(ll a, ll b) {
a -= b * (d - 1);
return C(a - 1, b - 1);
}
int main() {
__;
cin >> L >> d >> m;
jie[0] = 1;
for (ll i = 1; i <= L; ++i) {
jie[i] = (jie[i - 1] * i) % mod;
}
for (int i = 1; i <= m; ++i)cin >> a[i].t >> a[i].p;
sort(a + 1, a + 1 + m);
f[0] = 1;
s[0] = 1;
for (int i = 1; i <= L; ++i) {
if (i - d >= 0) {
f[i] = s[i - d];
}
s[i] = (s[i - 1] + f[i]) % mod;
}
for (int i = 1; i <= m; ++i) {
g[i] = calc(a[i].p, a[i].t);
for (int j = i - 1; j >= 1; --j) {
if (a[j].t < a[i].t) {
g[i] = (g[i] - (calc(a[i].p - a[j].p, a[i].t - a[j].t) * g[j]) % mod + mod) % mod;
}
}
}
ll ans = f[L];
for (int i = 1; i <= m; ++i) {
ans = (ans - (f[L - a[i].p] * g[i]) % mod + mod) % mod;
}
cout << ans << endl;
return 0;
}