Description
给出一个由前kk个大写字母组成的字符串,每次可以花费代价删去该字符串中所有的第ii个大写字母,一个字符串的代价为该字符串每对相邻字符的代价,给出一的矩阵表示两个相邻的大写字母的代价矩阵,问有多少种删除方案使得删除的代价以及剩余字符串的代价之和不超过TT,注意剩余字符串需非空
Input
第一行三个整数表示字符串长度,所用字母种类以及代价上限,之后输入一个由前kk个大写字母组成的长度为的字符串,以及kk个整数表示删去第ii种字符的代价,最后输入一个的矩阵m[x][y]m[x][y]表示相邻两字母的代价矩阵
(1≤n≤2⋅105,1≤k≤22,1≤T≤2⋅109,1≤ti,m[x][y]≤109)(1≤n≤2⋅105,1≤k≤22,1≤T≤2⋅109,1≤ti,m[x][y]≤109)
Output
输出满足条件的方案数
Sample Input
5 3 13
BACAC
4 1 2
1 2 3
2 3 4
3 4 10
Sample Output
5
Solution
用kk位表示每种字母是否需要删除,11表示要删去该种字符
考虑位置的两个字符x,yx,y,假设两个位置之间出现字符的状态为SS,那么删去状态之后代价增加m[x][y]m[x][y],令f[S]f[S]表示SS作为两个位置之间字符状态,删去后由这两个位置字符产生的代价和,那么对于第个位置的字符xx,位置在其后且可以与其产生代价的字符的种类不会超过k−1k−1个(必须是ii位置后第一次出现的字符才可能与xx产生代价),故满足条件的对数不会超过nknk,可以直接暴力找到,两者之间的状态也很好维护,注意SS中若包含字符则不计入代价
令dp[S]dp[S]表示删去SS状态后的代价,只要则为一种合法方案。直观上看,dp[S]dp[S]的值是由若干相邻字符代价构成,而每对相邻字符在原先字符串中可能并不相邻,而是通过删去某个SS的子状态使其相邻的。但直接枚举的子状态S′S′累加f[S′]f[S′]赋予dp[S]dp[S]并不正确,例如对于状态S′S′,产生了一个代价m[x][y],x,y∉S′m[x][y],x,y∉S′,而x∈Sx∈S或y∈Sy∈S,这就导致m[x][y]m[x][y]这一并不应该被记入dp[S]dp[S]的代价通过f[S′]f[S′]记入了dp[S]dp[S],故需要通过以下容斥原理对ff做修正:
f[S|2x]−=m[x][y]f[S|2x]−=m[x][y]
f[S|2y]−=m[x][y]f[S|2y]−=m[x][y]
f[S|2x|2y]+=m[x][y]f[S|2x|2y]+=m[x][y]
之后则有dp[S]=∑S′⊂Sf[S′]dp[S]=∑S′⊂Sf[S′],高维前缀和即可求出dp[S]dp[S],时间复杂度O(n⋅k+k⋅2k)O(n⋅k+k⋅2k)
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
int n,k,T,pre[25],t[25],m[25][25],dp[(1<<22)+5];
char s[200005];
int main()
{
scanf("%d%d%d%s",&n,&k,&T,s);
for(int i=0;i<k;i++)scanf("%d",&t[i]);
for(int i=0;i<k;i++)
for(int j=0;j<k;j++)
scanf("%d",&m[i][j]);
memset(pre,-1,sizeof(pre));
int S=0;
for(int i=0;i<k;i++)dp[1<<i]=t[i];
for(int i=0;i<n;i++)
{
int y=s[i]-'A';
S|=(1<<y);
for(int x=0;x<k;x++)
if(pre[x]>=0)
{
if(!((pre[x]>>x)&1)&&!((pre[x]>>y)&1))
{
dp[pre[x]]+=m[x][y];
dp[pre[x]|(1<<x)]-=m[x][y];
dp[pre[x]|(1<<y)]-=m[x][y];
dp[pre[x]|(1<<x)|(1<<y)]+=m[x][y];
}
pre[x]|=(1<<y);
}
pre[y]=0;
}
int K=1<<k;
for(int i=0;i<k;i++)
for(int j=0;j<K;j++)
if((j>>i)&1)dp[j]+=dp[j^(1<<i)];
int ans=0;
for(int i=0;i<K;i++)
if((i&S)==i&&dp[i]<=T&&i!=S)
ans++;
printf("%d\n",ans);
return 0;
}