题目描述
limpidlimpidlimpid 和 SSS 酱在传输秘密信息,秘密信息可以看成一个数字 xxx。
SSS 酱决定将秘密信息 xxx 编码成一个字符串 SSS。
而 limpidlimpidlimpid 决定解密这个 xxx 是多少。当他知道 SSS 后,他会将其还原成真正的解码串 Sn′S'_nSn′,具体的还原方式为:
其中,aia_iai 表示字符串 SSS 第 iii 个位置上的字符(从 111 开始编号),加号表示拼接运算。
在知道真正的解码串之后,limpidlimpidlimpid 会根据与 SSS 酱之前商定好的 TTT 开始解密,其中 xxx 为 TTT 在 Sn′S'_nSn′ 中以子序列形式出现的次数。
如果你是 limpidlimpidlimpid ,告诉你 SSS,TTT ,你能帮助他解密得到秘密信息 xxx 吗。
由于答案可能很大,你只需要输出 xxx 模 998244353998 244 353998244353 的值即可。
输入
第一行输入两个字符串 SSS, T(1≤∣S∣,∣T∣≤100)T(1 ≤ |S|, |T| ≤ 100)T(1≤∣S∣,∣T∣≤100)。保证两个字符串仅包含小写字母。
输出
输出一个整数表示 xxx 在模 998244353998 244 353998244353 意义下的值。
样例
input
aba ba
output
5
题解
我们根据题目推导 abaabaaba 逐步转化过程如下
S1′=aS'_1=aS1′=a
S2′=abaS'_2=abaS2′=aba
S3′=abaaabaS'_3=abaaabaS3′=abaaaba
就是说在由 Si′S'_iSi′ 到 Si+1′S'_{i+1}Si+1′ 的转变就是前后各放一个 Si′S'_iSi′ ,并在中间加入原字符串 SSS 的第 iii项。
长度为 nnn 的字符串在变化后最终长度会变成 2n2^n2n ,题目中长度为 100100100 ,显然不能直接抠出,也不能通过前缀与后缀利用数学方法来计算。那么我们考虑,这个不断变化的 Si′S'_iSi′ 是可以通过 Si−1′S'_{i-1}Si−1′ 得到的,而且对于子序列的贡献也很容易传递,然后考虑可以存下 Si′S'_iSi′ 对字符串 TTT 一个区间子序列的数量贡献是多少,那么很明显就可以是一道 区间DP区间DP区间DP 的题了。
状态转移方程为: dp[i][l][r]+=dp[i−1][l][k]∗dp[i−1][k+1][r]dp[i][l][r]+= dp[i - 1][l][k] * dp[i - 1][k + 1][r]dp[i][l][r]+=dp[i−1][l][k]∗dp[i−1][k+1][r]
表示 Si′S'_iSi′ 中包含几个 TTT 的从 lll 到 rrr 的子序列。
同时我们考虑如果原字符串 SSS 的第 iii 项和字符串 TTT 的某一项匹配,是不是也会对我们的答案有贡献。
就是说:if(S[i]==T[k+1])dp[i][l][r]+=dp[i−1][l][k]∗dp[i−1][k+2][r]if (S[i] == T[k + 1])dp[i][l][r] += dp[i - 1][l][k] * dp[i - 1][k + 2][r] if(S[i]==T[k+1])dp[i][l][r]+=dp[i−1][l][k]∗dp[i−1][k+2][r]
然后这个题基本上就做出来了,以下是两份代码,为了方便我学会,对原代码的码风进行了修改
代码一
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double dou;
template<typename _T> inline void read(_T &x){
char c=getchar(); bool f=0; x=0;
for(; c<'0'||c>'9'; c=getchar()) f|=(c=='-');
for(; c>='0'&&c<='9'; c=getchar()) x=(x<<1)+(x<<3)+(c^48);
x=(f)?(-x):x;
}
const int N=105, mod=998244353;
inline void add(int &x, int y){x+=y; if(x>=mod) x-=mod;}
char s[N], t[N];
int n, m;
int f[N][N][N];
int main(){
scanf("%s", s+1); scanf("%s", t+1);
n=strlen(s+1); m=strlen(t+1);
for(int i=1; i<=m; ++i)
f[1][i][i]=(t[i]==s[1]);//S_1与字符串T的第i位相等时,f[1][i][i]表示S_1包含1个从T_i
for(int i=2; i<=n; ++i)
{
for(int l=1; l<=m; ++l)
for(int r=l; r<=m; ++r)
add(f[i][l][r], 2ll*f[i-1][l][r]%mod);
//先把S_{i-1}的贡献 * 2 加上
for(int l=1; l<=m; ++l)
for(int r=l+2; r<=m; ++r)
for(int k=l+1; k<r; ++k)
if(s[i]==t[k]) add(f[i][l][r], (ll)f[i-1][l][k-1]*f[i-1][k+1][r]%mod);//状态转移且原字符串S第i位有贡献
for(int l=1; l<=m; ++l)
for(int r=l+1; r<=m; ++r)
for(int k=l+1; k<=r; ++k)
add(f[i][l][r], (ll)f[i-1][l][k-1]*f[i-1][k][r]%mod);//状态转移,但原字符串S的第i位没有贡献
for(int v=1; v<=m; ++v)
if(s[i]==t[v]) add(f[i][v][v], 1);//存下原字符串S第i位与T第v位相同时对 f[i][v][v] 贡献 1
for(int l=2; l<=m; ++l)
for(int r=l; r<=m; ++r)
if(s[i]==t[l-1]) add(f[i][l-1][r], f[i-1][l][r]);//另外加上最左边
for(int l=1; l<m; ++l)
for(int r=l; r<m; ++r)
if(s[i]==t[r+1]) add(f[i][l][r+1], f[i-1][l][r]);//另外加上最右边
}
printf("%d\n", f[n][1][m]);
}
代码二
#include <bits/stdc++.h>
#define int long long
#define mod 998244353
#define N 110
using namespace std;
int dp[N][N][N], n, m;
string S, T;
signed main()
{
cin >> S >> T;
n = S.length(); m = T.length();
S = ' ' + S; T = ' ' + T;
for (int i = 0; i <= n; i ++)
for (int j = 1; j <= m + 1; j ++)
for (int t = 0; t < j; t ++)
dp[i][j][t] = 1;
//匹配空串,在后面的状态转移过程中不需要特意讨论边界
for (int i = 1; i <= n; i++)
for (int l = 1; l <= m; l++)
for (int r = l; r <= m; r++)
{
for (int k = l - 1; k <= r; k ++)
dp[i][l][r] = (dp[i][l][r] + dp[i - 1][l][k] * dp[i - 1][k + 1][r]) % mod;
//dp[i][l][r] 表示 S_i 中能匹配从 T[l] 到 T[r] 这段的子序列次数
//由于 S_i 是由 S_{i - 1} 和 原字符串中第i位的 a[i] 和 S_{i - 1}拼接而成
//则 dp[i][l][r]首先可以通过 dp[i - 1][l][k] * dp[i - 1][k + 1][r] 得到
for (int k = l - 1; k < r; k ++)
if (S[i] == T[k + 1])//当原字符串 S 的第 i 位与字符串 T 的第 k + 1 位相等时
dp[i][l][r] = (dp[i][l][r] + dp[i - 1][l][k] * dp[i - 1][k + 2][r]) % mod;
//dp[i][l][r]加上 dp[i - 1][l][k] * dp[i - 1][k + 2][r]的贡献
}
cout << dp[n][1][m] << '\n';//dp[n][1][m]表示S_n中能匹配从T[1]到T[m]这段的子序列次数
return 0;
}