题目链接:https://codeforces.com/gym/101741/problem/J
题目大意:
你有一个数组a含n个整数,还有一个模数m,你要处理q询问,每个询问给你一段区间[l, r],问你这段区间中有多少个子序列,使得子序列之和模m等于零。注意,空集也算子序列。答案模1e9 + 7。
解题思路:
针对一次询问,我们很容易能得到一个状态转移方程,令dp[i][j] 为区间[x, i](起点未知)内 子序列之和模m等于j 的方案数,则有
dp[i][j] += dp[i][j] + dp[i-1][j] + dp[i-1][(j-a[i]+m) %m]。问题是如果对每次询问都跑一边这个类似01背包的东西,妥妥超时。所以我们想着能不能用分治做。如果我们可以得到[l, mid]的从右往左计算的方案数和[mid+1, r]从左往右计算的方案数,那么如果询问的区间在[l, r]之内且横跨mid,令查询区间为[x, y],答案将为(dp1[x][0] * dp2[y][0]) + sigma(0<i<m) dp1[x][i]*dp[y][m-i]。如果查询不横跨mid,照这样分治下去即可。注意分治不处理单点情况,要特判。
代码如下:
# include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;
const ll maxn = 2e5 + 5;
ll lv[maxn][21], rv[maxn][21]; //当然也可以只用一个数组
ll ans[maxn];
ll q[maxn][2];
ll a[maxn];
ll n, m, qn;
vector <ll> id;
void divide(ll l, ll r, vector<ll> id){
ll mid = (l + r) / 2;
for(ll i = mid+1; i >= l; --i)
for(ll j = 0; j < m; ++j)
lv[i][j] = 0;
lv[mid+1][0] = 1;
for(ll i = mid; i <= r; ++i)
for(ll j = 0; j < m; ++j)
rv[i][j] = 0;
rv[mid][0] = 1;
for(ll i = mid; i >= l; --i) //lv范围是[l, mid]
for(ll j = m-1; j >= 0; --j)
lv[i][j] = (lv[i][j] + lv[i+1][j]) % mod + lv[i+1][(j-a[i]+m) %m] % mod;
for(ll i = mid+1; i <= r; ++i) //rv范围是[mid+1, r]
for(ll j = m-1; j >= 0; --j)
rv[i][j] = (rv[i][j] + rv[i-1][j]) % mod + rv[i-1][(j-a[i]+m) %m] % mod;
vector <ll> id_l, id_r;
for(ll i = 0; i < id.size(); ++i){
ll l = q[id[i]][0], r = q[id[i]][1], _id = id[i]; //注意这里的l,r是询问区间的,不是函数形参
if(l <= mid && r > mid){
ans[_id] = (lv[l][0] * rv[r][0]) % mod;
for(ll j = 1; j < m; ++j){
ans[_id] += lv[l][j] * rv[r][m-j];
ans[_id] %= mod;
}
}
else if(r <= mid) id_l.push_back(_id);
else id_r.push_back(_id);
}
if(id_l.size() && l < mid) divide(l, mid, id_l);
if(id_r.size() && r > mid+1) divide(mid+1, r, id_r);
//如果区间[l,r]只有两个点就没必要分治下去了
}
int main(){
std::ios::sync_with_stdio(false);
while(cin >> n >> m){
for(ll i = 1; i <= n; ++i){
cin >> a[i];
a[i] %= m;
}
cin >> qn;
for(ll i = 1; i <= qn; ++i){
cin >> q[i][0] >> q[i][1];
if(q[i][0] == q[i][1]){
if(a[ q[i][0] ] > 0) ans[i] = 1;
else ans[i] = 2;
} else {
id.push_back(i);
}
}
divide(1, n, id);
for(ll i = 1; i <= qn; ++i)
cout << ans[i] << endl;
}
return 0;
}