【题目链接】
【思路要点】
- 我们显然可以得到一个直接的DP做法,记\(F_{i,j}\)表示在前\(i\)个集合中取\(j\)个元素的方案数,则显然有$$ \left\{\begin{aligned}F_{0,0} = 1\\F_{i,j}=\sum_{k=max(0,j-M_{i})}^{j}F_{i-1,k}\end{aligned}\right.$$
- 运用滚动数组,对转移稍加优化,可以得到一个空间复杂度为\(O(B)\),时间复杂度为\(O(N*B)\)的做法,但本题时间限制较紧,这种方法不能通过。
- 若每个集合中元素个数没有上界,那么问题就成为了一个经典的组合数学问题,假设共有\(N\)个集合,要选出\(M\)个元素,显然方案数为\(\dbinom{M+N-1}{N-1}\)。
- 本题中存在上界,不容易处理,但我们注意到若取的元素个数只存在下界是容易处理的,只需将\(A\)和\(B\)均减去下界即可。考虑利用容斥原理将上界化为下界。我们称在集合\(i\)中取超过\(M_{i}\)个元素为“不符合集合\(i\)的限制”,则由容斥原理,可得$$Ans=\sum_{i=0}^{N}(-1)^{i}*Total_{i}(其中Total_{i}表示不符合至少i个集合的限制的方案数)$$
- 那么,我们枚举每个集合是否一定不符合限制,即可将问题转化为只有下界的取球问题。
- 接下来唯一的问题在于如何求解\(\dbinom{M}{N} Mod\ \ 2004\)。
- 令\(\dbinom{M}{N} Mod\ \ 2004=X\),也即\(\frac{M!}{N!(M-N)!}=X+2004K(K\in Z)\)。
- 在等式两侧同时乘以\(N!\)得到\(\frac{M!}{(M-N)!}=X*N!+2004N!*K(K\in Z)\),也即\(A_{M}^{N}\ \ Mod\ \ 2004N!=X*N!\)。
- 所以,将模数乘以\(N!\)计算,在最后将答案除以\(N!\)即可,此时模数最大为\(7272115200\),而乘数总在\(10^{7}\)以内,所以使用64位整型即可通过本题。
- 时间复杂度\(O(2^{N}*N)\)。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 15; const int TP = 2004; template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int n, l, r, a[MAXN]; long long P, ans; void check(int l, int r, int f) { long long now = 1, tmp = 1; for (int i = 1; i <= n; i++) { now = now * (r + n - i + 1) % P; tmp = tmp * (l + n - i) % P; } if (f == 1) ans += now - tmp; else ans += tmp - now; } void work(int pos, int l, int r, int f) { if (r == -1) return; if (pos <= n) { work(pos + 1, l, r, f); work(pos + 1, max(l - a[pos] - 1, 0), max(r - a[pos] - 1, -1), -f); } else check(l, r, f); } int main() { read(n), read(l), read(r); for (int i = 1; i <= n; i++) read(a[i]); P = TP; for (int i = 1; i <= n; i++) P *= i; work(1, l, r, 1); ans = (ans % P + P) % P; for (int i = 1; i <= n; i++) ans /= i; printf("%lld\n", ans); return 0; }