发现 2 ≤ k ≤ 8 2 \leq k \leq 8 2≤k≤8,于是可以考虑状压 DP。
我们可以设 f [ l ] [ r ] [ S ] f[l][r][S] f[l][r][S] 表示区间 [ l , r ] [l, r] [l,r] 经过若干次操作变成 01 串 S S S 的最大分数。
基于贪心的思想,我们一定要使得整个 01 串最终不可操作,所以 01 串 S S S 的长度不超过 k − 1 k - 1 k−1,否则 S S S 可以继续进行合并,直到其长度小于等于 k − 1 k - 1 k−1。
那么我们有一个很暴力的做法:
对于 l , r l, r l,r,枚举中点 m i d mid mid,再枚举区间 [ l , m i d ] [l, mid] [l,mid] 和 [ m i d + 1 , r ] [mid + 1, r] [mid+1,r] 的所有状态,同时枚举合并的区间 [ p , p + k − 1 ] [p, p + k - 1] [p,p+k−1],直接转移即可。
这样做的时间复杂度是 O ( k n 3 2 2 k − 2 ) O(kn^32^{2k - 2}) O(kn322k−2) 的,会 T 飞。
这个时候,一般的想法是找到一些有利于统计的性质。
我们会发现,一个 p p p 可能会被多个区间枚举到,但是它们转移的效果是一样的。
举个例子:
假设我们有两个区间 [ l 1 , r 1 ] , [ l 2 , r 2 ] [l_1, r_1], [l_2, r_2] [l1,r1],[l2,r2]。
前者枚举出的状态分别为: 00010 , 10100 00010, 10100 00010,10100;
后者枚举出的状态分别为: 00101 , 01000 00101, 01000 00101,01000;
如果前者取区间 [ 4 , 8 ] [4, 8] [4,8] 合并,后者取区间 [ 3 , 7 ] [3, 7] [3,7] 合并,那么上面的两个区间合并出来是等价的。
所以我们遵从最简化原则,对于每个区间,我们钦定后面一部分区间(即 [ m i d + 1 , r ] [mid + 1, r] [mid+1,r] 这一部分)的状态只能为 0 或 1(即长度为 1 的 01 串)。实际上,这样的操作相当于不断往一个串中添 0 或 1。
接着,我们可以发现,长度为 l e n len len 的串,最终一定会被操作成长度为 ( l e n − 1 ) m o d ( k − 1 ) + 1 (len - 1)\bmod (k - 1) + 1 (len−1)mod(k−1)+1 的串。也就是说,只要最初的 01 串的长度是确定的,那么最终的 01 串长度同样是确定的。
所以,如果一段区间的长度 l e n len len 满足 ( l e n − 1 ) m o d ( k − 1 ) = 0 (len - 1)\bmod (k - 1) = 0 (len−1)mod(k−1)=0,那么经过若干次操作后,这个区间一定可以变成长度为 1 的 01 串。
因此,枚举 m i d mid mid 时,我们不需要一个个位置地枚举,只需要满足 ( r − m i d − 1 ) m o d ( k − 1 ) = 0 (r - mid - 1) \bmod (k - 1) = 0 (r−mid−1)mod(k−1)=0 即可。
(所以,DP 状态中的 S S S 实际上可能带有若干个前导 0,取决于最终变成的 01 串的长度)
综合上述性质,我们就有了时间复杂度是 O ( n 2 ⌊ n k ⌋ 2 k − 1 ) O(n^2\lfloor\frac{n}{k}\rfloor 2^{k - 1}) O(n2⌊kn⌋2k−1) 的做法,不过这只是一个上界,并且区间 DP 本身就是跑不满的,所以这个复杂度是可以通过的。
但是,上述做法就不能在转移的过程中计算答案了。
不过,结合前面的这句 “实际上,这样的操作相当于不断往一个串中添 0 或 1”,我们可以发现,当 ( l e n − 1 ) m o d ( k − 1 ) = 0 (len - 1) \bmod (k-1) = 0 (len−1)mod(k−1)=0 时,01 串中正好添加了 k k k 个字符。此时,我们就可以把它们再合并一次,同时计算答案了。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cassert>
#include<cstring>
#define MAXN 305
#define MAXM 8
using LL = long long;
using namespace std;
int n, k, a[MAXN], c[1 << MAXM];
LL ans, w[1 << MAXM], f[MAXN][MAXN][1 << MAXM];
const LL INF = 0x3F3F3F3F3F3F3F3F;
int main()
{
#ifdef FILE
freopen("Input.in", "r", stdin);
freopen("Output.out", "w", stdout);
#endif
scanf("%d%d", &n, &k);
for(register int i = 1; i <= n; i++) scanf("%d", a + i);
for(register int i = 0; i < (1 << k); i++) scanf("%d%lld", c + i, w + i);
for(register int i = 1; i <= n; i++) for(register int j = 1; j <= n; j++) for(register int S = 0; S < (1 << k); S++) f[i][j][S] = -INF;
for(register int i = 1; i <= n; i++) f[i][i][a[i]] = 0;
for(register int len = 2; len <= n; len++)
{
for(register int l = 1, r = l + len - 1; r <= n; ++l, ++r)
{
int t = (len - 1) % (k - 1);
if(t == 0) t = k - 1;
for(register int mid = r - 1; mid >= l; mid -= k - 1)
{
for(register int S = 0; S < (1 << t); S++)
{
f[l][r][S << 1] = max(f[l][r][S << 1], f[l][mid][S] + f[mid + 1][r][0]); //添加一个 0
f[l][r][S << 1 | 1] = max(f[l][r][S << 1 | 1], f[l][mid][S] + f[mid + 1][r][1]); //添加一个 1
}
}
if(t == k - 1) //再次合并
{
LL g[2] = {-1LL, -1LL};
for(register int S = 0; S < (1 << k); S++) g[c[S]] = max(g[c[S]], f[l][r][S] + w[S]);
f[l][r][0] = g[0], f[l][r][1] = g[1];
}
}
}
ans = -INF;
for(register int i = 0; i < (1 << (k - 1)); i++) ans = max(ans, f[1][n][i]);
printf("%lld\n", ans);
return 0;
}