题意:
给你两个串s1,s2,让你求s2的所有后缀在s1中出现的频率乘以后缀长度的和。
做法:
逆向思维,我们发现如果将s1,s2翻转,s2原来的所有后缀就变成了前缀。
我们知道扩展KMP中extend[i]的含义是文本串后缀i与模式串的最长公共前缀的长度
如果我们对s1,s2使用扩展KMP,求得extend[i],他表示的是s2整个字符串与s1后缀i的最长公共前缀的长度ni,那么s2的各个前缀与s1的后缀i的最长公共前缀的范围在[1,ni]
所以我们遍历整个s1,利用等差数列求和即可。
举个例子:
s1 = “abababab” s2 = “ab”
翻转后:
s2:
b a
s1:
b a b a b a b a
extend:
1 2 1 2 1 2 1 2
sum = 4*(1+2) = 12
AC代码:
#include<bits/stdc++.h>
#define IO ios_base::sync_with_stdio(0),cin.tie(0),cout.tie(0)
#define pb(x) push_back(x)
#define sz(x) (int)(x).size()
#define sc(x) scanf("%d",&x)
#define abs(x) ((x)<0 ? -(x) : x)
#define all(x) x.begin(),x.end()
#define mk(x,y) make_pair(x,y)
#define fin freopen("in.txt","r",stdin)
#define fout freopen("out.txt","w",stdout)
using namespace std;
typedef long long ll;
typedef pair<int,int> PII;
const int mod = 1e9+7;
const double PI = 4*atan(1.0);
const int maxm = 1e8+5;
const int maxn = 1e6+5;
const int INF = 0x3f3f3f3f;
const ll LINF = 1ll<<62;
char t[maxn];
char p[maxn];
int nex[maxn];//模式串所有后缀i与它自己的LCP的长度
int extend[maxn];//文本串所有后缀i与模式串的LCP的长度
void get_nex(char *p,int m) //p模式串
{
nex[0] = m;
int j = 0;
while(j+1<m && p[j] == p[j+1]) j++;
nex[1] = j;
int k = 1;
for(int i=2;i<m;i++){
int pos = nex[k]+k-1;
int l = nex[i-k];
if(i+l < pos+1) nex[i] = l;
else{
j = max(0,pos-i+1);
while(i+j<m && p[i+j] == p[j]) j++;
nex[i] = j;
k = i;
}
}
}
void exkmp(char *t,char *p,int n,int m) //t文本串,p模式串,n文本串长度,m模式串长度
{
get_nex(p,m);
int j = 0;
while(j<n && j<m && t[j] == p[j]) j++;
extend[0] = j;
int k = 0;
for(int i=1;i<n;i++){
int pos = extend[k]+k-1; //pos:以k为起始位置字符匹配的最右边界
int l = nex[i-k];
if(i+l<pos+1) extend[i] = l;
else{
j = max(0,pos-i+1);
while(i+j<n && j<m && t[i+j] == p[j]) j++;
extend[i] = j;
k = i;
}
}
}
int main()
{
// fin;
IO;
int num;
cin>>num;
while(num--)
{
cin>>t>>p;
int m = strlen(p),n = strlen(t);
reverse(t,t+n);
reverse(p,p+m);
exkmp(t,p,n,m);
ll ans = 0;
for(int i=0;i<n;i++)
ans = (ans+1ll*extend[i]*(extend[i]+1)/2)%mod;
printf("%lld\n",ans);
}
return 0;
}