代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MX=2e5+9;
char s[MX],p[MX];
int exnext[MX],exten[MX],n,m;
void getexnext(){
m=strlen(p);
exnext[0]=m;
int i=0;
while( p[i]==p[i+1] && i<m-1 )
i++;
exnext[1]=i;
int k=i+1,po=1;
for( int i=2; i<m ; i++ ){
if( exnext[i-po]<k-i )
exnext[i]=exnext[i-po];
else{
int j=k-i;
if( j<=0 )
j=0;
while( p[j]==p[j+i] && i+j<m && j<m )
j++;
exnext[i]=j;
k=i+j;
po=i;
}
}
return ;
}
void exkmp(){
getexnext();
n=strlen(s);
int i=0;
while( s[i]==p[i] && i<m && i<n )
i++;
exten[0]=i;
int k=i,po=0;
for( int i=1 ; i<n ; i++ ){
if( exnext[i-po]<k-i )
exten[i]=exnext[po-i];
else{
int j=k-i;
if( j<=0 )
j=0;
while( p[j]==s[i+j] && i+j<n && j<m )
j++;
exten[i]=j;
k=i+j;
po=i;
}
}
return ;
}
int main()
{
//freopen("input.txt","r",stdin);
scanf("%s %s",s,p);
exkmp();
ll ans=0;
for( int i=0 ; i<n ; i++ ){
ans+=min(exten[i],m-1);
if( s[i+exten[i]]<p[exten[i]] )
ans+=(n-i-exten[i]);
}
printf("%lld\n",ans);
return 0;
}