Problem Description
You are given two strings s and t composed of digits (characters ‘0’ to ‘9’). The length of s is n and the length of t is m. Neither s nor t starts with ‘0’.
Please calculate the number of valid subsequences of s that are larger than t if viewed as positive integers. A subsequence is valid if and only if its first character is not ‘0’.
Two subsequences are different if they are composed of different locations in the original string. For example, string “1223” has 2 different subsequences “23”.
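To make the "different locations" rule concrete, here is a small brute-force sketch. The function name countOccurrencesAsSubsequence is hypothetical (not part of the problem). It enumerates position sets of s with a bitmask and counts those whose characters spell a given pattern, assuming s is short (at most 20 characters) so that all 2^|s| masks can be tried.

```cpp
#include <cassert>
#include <string>

// Brute force: count position sets of s whose characters, read left to right,
// spell exactly `pattern`. Each bitmask is a distinct set of positions, so two
// equal-looking subsequences built from different positions count separately.
long long countOccurrencesAsSubsequence(const std::string& s,
                                        const std::string& pattern) {
    assert(s.size() <= 20);  // exponential enumeration, tiny inputs only
    const int n = static_cast<int>(s.size());
    long long count = 0;
    for (int mask = 1; mask < (1 << n); ++mask) {
        std::string picked;
        for (int i = 0; i < n; ++i)
            if (mask >> i & 1) picked += s[i];  // keep positions whose bit is set
        if (picked == pattern) ++count;
    }
    return count;
}
```

For instance, countOccurrencesAsSubsequence("1223", "23") returns 2: the ‘3’ can be paired with either of the two ‘2’s.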
Because the answer may be huge, please output the answer modulo 998244353.
Input Description:
The first line contains one integer T, indicating that there are T tests.
Each test consists of 3 lines.
The first line of each test contains two integers n and m, denoting the length of strings s and t.
The second line of each test contains the string s.
The third line of each test contains the string t.
- 1 ≤ m ≤ n ≤ 3000.
- The sum of n over all tests is at most 3000.
- The first character of both s and t is not ‘0’.
Output Description:
For each test, output one integer in a line representing the answer modulo 998244353.
Example 1
Input
3
4 2
1234
13
4 2
1034
13
4 1
1111
2
Output
9
6
11
Explanation
For the last test, there are 6 subsequences “11”, 4 subsequences “111” and 1 subsequence “1111” that are valid, so the answer is 11.
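A brute-force reference is a cheap way to confirm the sample answers before trusting a faster solution. The sketch below uses a hypothetical function name, bruteForceCount, and assumes n ≤ 20. It enumerates all position sets, discards subsequences that start with ‘0’, and compares the rest with t by length first and then lexicographically, which matches numeric order when neither string has a leading zero.

```cpp
#include <cassert>
#include <string>

// Reference brute force for the statement: count valid subsequences of s
// (first character not '0') that are larger than t as integers, mod 998244353.
long long bruteForceCount(const std::string& s, const std::string& t) {
    const int n = static_cast<int>(s.size());
    assert(n <= 20);  // 2^n position sets, tiny inputs only
    const long long MOD = 998244353;
    long long count = 0;
    for (int mask = 1; mask < (1 << n); ++mask) {
        std::string sub;
        for (int i = 0; i < n; ++i)
            if (mask >> i & 1) sub += s[i];
        if (sub[0] == '0') continue;  // invalid: leading zero
        // No leading zeros on either side: more digits means a larger number,
        // and equal lengths compare lexicographically like the numbers do.
        bool larger = sub.size() != t.size() ? sub.size() > t.size() : sub > t;
        if (larger) count = (count + 1) % MOD;
    }
    return count;
}
```

On the three samples it returns 9, 6 and 11, matching the expected output.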
Problem Summary
Given two digit strings a and b (s and t in the statement), count how many valid subsequences of a are larger than b.
Approach
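First, every valid subsequence with more than m digits is larger than b, because neither has a leading zero and more digits means a larger number. So only subsequences with exactly m digits need a digit-by-digit comparison. The longer ones are counted directly: fix the first chosen position i with a[i] ≠ ‘0’, then take any j ≥ m of the n − i positions after it, which contributes C(n−i, j) for each j from m to n−i.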
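The counting of longer subsequences can be sketched on its own. This is an illustration under stated assumptions, not the author's code: countLonger is a hypothetical name, s is passed as a 0-indexed std::string, and binomials come from a per-call Pascal's triangle rather than the post's global table.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Counts valid subsequences of s with more than m digits; all of them beat t.
// For each start position i (1-based) with s[i] != '0', any k >= m of the
// n - i later positions may follow: sum over k = m..n-i of C(n-i, k).
long long countLonger(const std::string& s, int m) {
    const long long MOD = 998244353;
    const int n = static_cast<int>(s.size());
    assert(1 <= m && m <= n);
    // Pascal's triangle: C[x][y] = C(x, y) mod MOD, zero when y > x.
    std::vector<std::vector<long long>> C(n + 1, std::vector<long long>(n + 1, 0));
    for (int x = 0; x <= n; ++x) {
        C[x][0] = 1;
        for (int y = 1; y <= x; ++y) C[x][y] = (C[x - 1][y - 1] + C[x - 1][y]) % MOD;
    }
    long long total = 0;
    for (int i = 1; i <= n - m; ++i) {
        if (s[i - 1] == '0') continue;  // a subsequence may not start with '0'
        for (int k = m; k <= n - i; ++k) total = (total + C[n - i][k]) % MOD;
    }
    return total;
}
```

For the third sample, countLonger("1111", 1) gives 7 + 3 + 1 = 11, which is the whole answer there because no single digit ‘1’ exceeds ‘2’.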
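dp[i][j] is the number of subsequences of a[1..i] that equal b[1..j], the first j characters of b. For the equal-length case, a subsequence first exceeds b at position j when its first j−1 digits equal b[1..j−1] and its j-th digit is some a[i] > b[j]. There are dp[i−1][j−1] ways to pick the matched prefix and C(n−i, m−j) ways to pick the remaining m−j digits from the positions after i. See the code below for details.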
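The equal-length count can be isolated as follows. This sketch swaps the post's 2-D dp[i][j] table for a rolling 1-D array (a memory-saving variant, not the author's code), and countEqualLength is a hypothetical name. Iterating j from high to low keeps dp[j−1] equal to its value for the prefix a[1..i−1] when it is read.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Counts valid subsequences of s with exactly m = |t| digits that exceed t.
// dp[j] = number of subsequences of the processed prefix of s equal to t[1..j].
long long countEqualLength(const std::string& s, const std::string& t) {
    const long long MOD = 998244353;
    const int n = static_cast<int>(s.size()), m = static_cast<int>(t.size());
    assert(1 <= m && m <= n);
    std::vector<std::vector<long long>> C(n + 1, std::vector<long long>(n + 1, 0));
    for (int x = 0; x <= n; ++x) {
        C[x][0] = 1;
        for (int y = 1; y <= x; ++y) C[x][y] = (C[x - 1][y - 1] + C[x - 1][y]) % MOD;
    }
    std::vector<long long> dp(m + 1, 0);
    dp[0] = 1;  // the empty subsequence matches the empty prefix of t
    long long total = 0;
    for (int i = 1; i <= n; ++i) {
        const char c = s[i - 1];
        // Right to left, so dp[j - 1] still describes the prefix s[1..i-1].
        for (int j = std::min(m, i); j >= 1; --j) {
            // t[1..j-1] matched, then c > t[j]: already larger; choose the
            // other m - j digits from the n - i later positions (if enough).
            if (c > t[j - 1] && m - j <= n - i)
                total = (total + C[n - i][m - j] * dp[j - 1]) % MOD;
            if (c == t[j - 1]) dp[j] = (dp[j] + dp[j - 1]) % MOD;
        }
    }
    return total;
}
```

Leading zeros need no extra check: at j = 1 the digit c exceeds t[1] ≥ ‘1’, and for j > 1 the subsequence starts with t[1]. On the first sample it counts 14, 23, 24 and 34, giving 4; adding the 5 longer subsequences yields the expected 9.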
Code
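#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
#define maxn 3005
#define mod (998244353)
char a[maxn], b[maxn];
// dp[i][j]: number of subsequences of a[1..i] equal to b[1..j]
// c[i][j]: binomial coefficient C(i, j) mod 998244353
ll dp[maxn][maxn], c[maxn][maxn];
// Precompute Pascal's triangle for all n up to 3000.
void pre(){
    int i, j;
    c[0][0] = 1;
    for(i = 1; i <= 3000; i++){
        c[i][0] = 1;
        for(j = 1; j <= i; j++){
            c[i][j] = (c[i-1][j-1]+c[i-1][j])%mod;
        }
    }
}
int main(){
    int i, j, n, m, T;
    // Every prefix of a matches the empty prefix of b exactly once.
    for(i = 0; i <= 3000; i++) dp[i][0] = 1;
    pre();
    scanf("%d",&T);
    while(T--){
        scanf("%d%d",&n,&m);
        scanf("%s",a+1);   // strings are stored 1-indexed
        scanf("%s",b+1);
        ll ans = 0;
        // Case 1: subsequences with exactly m digits.
        for(i = 1; i <= n; i++){
            int lim = min(m, i);
            for(j = 1; j <= lim; j++){
                // Either skip a[i], or use it to extend a match of b[1..j-1].
                dp[i][j] = dp[i-1][j];
                if(a[i] == b[j]) dp[i][j] = (dp[i][j] + dp[i-1][j-1])%mod;
                // Match b[1..j-1], then a[i] > b[j]: already larger than b.
                // The remaining m-j digits come from the n-i later positions.
                if(a[i] > b[j]) ans = (ans + c[n-i][m-j]*dp[i-1][j-1])%mod;
            }
        }
        // Case 2: subsequences longer than m digits starting at a non-zero a[i];
        // choose j >= m more digits from the n-i positions after i.
        for(i = 1; i <= n-m; i++){
            if(a[i]=='0') continue;
            for(j = m; j <= n-i; j++) ans = (ans + c[n-i][j])%mod;
        }
        printf("%lld\n",ans);
    }
    return 0;
}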