Step1 Problem:
给你字符串 s1, s2. 求 s2 的每个后缀在 s1 中出现的次数 * 后缀长度 和,结果取模1e9+7.
例:
输入:
abababab
aba
输出:
19
Step2 Ideas:
我们将问题转换为,求 s2 的每个前缀在 s1 中出现的次数 * 前缀长度 和。
这样我们需要反转s1, s2.
KMP做法:
个人习惯从0开始,next[0] = 0;
next[i] : 代表下标从 0 到 i 这个子串,后缀 = 前缀 最长长度(不包括自身)。
反转s1, s2后:
直接去比较求各个前缀出现的次数,由于 next[] 的含义,当 j = next[j-1],长度为 next[j-1] 的前缀没能加1,此时该前缀应该在 s1 中又多出现了一次。
所以长度为 len 的前缀出现了几次,长度为 next[len-1] 的前缀需要补加几次。扩展KMP做法:
next[i]:i 位置开始的后缀串和原串的最长公共前缀长度。
extend[i]: i 位置开始的后缀串 和 另一个串(也就是求next[]的串) 的 最长公共前缀长度。
反转s1, s2后:
extend[i] = 3, s2 串长度为 3, 2, 1 的前缀在 i 位置各出现了一次。
Step3 Code:
//KMP做法
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 1000055
#define mod 1000000007
int nex[N];
long long num[N];
char s1[N], s2[N];
void get_nex(char s2[])
{
int i, j = 0, len = strlen(s2);
nex[j] = 0;
for(i = 1; i < len; i++)
{
j = nex[i - 1];
while(j && s2[j] != s2[i])
j = nex[j - 1];
if(s2[j] == s2[i]) nex[i] = ++j;
else nex[i] = 0;
}
}
void KMP(char s1[], char s2[])
{
int i = 0, j = 0, len1 = strlen(s1), len2 = strlen(s2);
while(i < len1)
{
if(s1[i] == s2[j])//相等
{
num[j+1]++;i++; j++;//s2长度为j+1的前缀在s1中出现的次数++
}
else
{
if(j == 0) i++;
else j = nex[j - 1];
}
if(j >= len2)
j = nex[j - 1];//s2长度为nex[j-1]的前缀 会少加
}
}
int main()
{
int T, i, len1, len2;
scanf("%d", &T);
while(T--)
{
scanf("%s %s", s1, s2);
len1 = strlen(s1), len2 = strlen(s2);
for(i = 0; i < len1 / 2; i++)
swap(s1[i], s1[len1 - i - 1]);
for(i = 0; i < len2 / 2; i++)
swap(s2[i], s2[len2 - i - 1]);
memset(num, 0, sizeof(num));
get_nex(s2); KMP(s1, s2);
long long ans = 0;
for(i = len2; i >= 1; i--)
{
ans += (num[i]*i)%mod, ans %= mod;
num[nex[i-1]] += num[i];//补加 匹配过程中 没加的次数
}
printf("%lld\n", ans);//输出
}
return 0;
}
//扩展KMP做法
#include<bits/stdc++.h>
using namespace std;
const int MOD = 1e9+7;
const int N = 1e6+5;
#define ll long long
char s1[N], s2[N];
int nex[N], extend[N];
ll vis[N];
void get_next(char s[])
{
int len = strlen(s);
nex[0] = len;
int mx = 0, id;
for(int i = 1; i < len; i++)
{
if(i < mx) nex[i] = min(mx-i, nex[i-id]);
else nex[i] = 0;
while(s[i+nex[i]] == s[nex[i]]) nex[i]++;
if(mx < i+nex[i]) {
id = i;
mx = i+nex[i];
}
}
}
void get_extend()
{
int len = strlen(s1);
int mx = 0, id;
for(int i = 0; i < len; i++)
{
if(i < mx) extend[i] = min(mx-i, nex[i-id]);
else extend[i] = 0;
while(s1[i+extend[i]] == s2[extend[i]] && s2[extend[i]] != '\0') extend[i]++;
if(mx < extend[i]+i) {
id = i;
mx = extend[i]+i;
}
}
}
int main()
{
int T;
scanf("%d", &T);
while(T--)
{
scanf("%s %s", s1, s2);
int len1 = strlen(s1), len2 = strlen(s2);
for(int i = 0; i < len1/2; i++)
swap(s1[i], s1[len1-i-1]);
for(int i = 0; i < len2/2; i++)
swap(s2[i], s2[len2-i-1]);
get_next(s2);
get_extend();
memset(vis, 0, sizeof(vis));
for(int i = 0; i < len1; i++)//长度extend[i]~1的前缀 在i位置出现
vis[extend[i]]++;
ll ans = (vis[len2]*len2)%MOD;
for(int i = len2-1; i >= 1; i--)
{
vis[i] += vis[i+1], vis[i] %= MOD;
ans += (vis[i]*i)%MOD, ans %= MOD;
}
printf("%lld\n", ans);
}
return 0;
}