题目大意
两个字符串a和b,从a中选取一个非空子串,从b中选取一个非空子串,然后拼起来变成回文串,求这个回文串最大长度。
SAM
先做一遍manacher处理出从某个串某个位置出发最长回文串。
先把b反过来。
然后你注意到这个回文串一定可以表示成c+S+c,其中S是一个回文串,然后c是a或b中选出的子串。
我的做法是把a和b一起丢去做广义SAM,然后计算每个节点是否同时出现在两个串中,并维护right集那个处理出的manacher的最大值,然后就可以算了。
#include<cstdio>
#include<algorithm>
#include<cstring>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define max(a,b) (a>b?a:b)
using namespace std;
const int maxn=800000+10;
int g[maxn*4][26],step[maxn*4],f[maxn*4],fail[maxn*4],mxa[maxn*4],mxb[maxn*4],a[maxn*4],sum[maxn*4];
int c[maxn],fa[maxn*2],fb[maxn*2],sa[maxn],sb[maxn];
char s1[maxn*2],s2[maxn*2];
int i,j,k,l,t,n,m,tot,top,ca,ans,last,mx;
bool czy;
int newnode(){
int i;
tot++;
if (tot<=mx){
f[tot]=0;
mxa[tot]=mxb[tot]=0;
fo(i,0,25) g[tot][i]=0;
}
return tot;
}
void add(int c){
int p=last,i;
if (g[p][c]){
int q=g[p][c];
if (step[q]==step[p]+1) last=q;
else{
int nq=newnode();
fo(i,0,25) g[nq][i]=g[q][i];
fail[nq]=fail[q];
step[nq]=step[p]+1;
fail[q]=nq;
while (p&&g[p][c]==q){
g[p][c]=nq;
p=fail[p];
}
last=nq;
}
return;
}
int np=newnode();
step[np]=step[p]+1;
while (p&&g[p][c]==0){
g[p][c]=np;
p=fail[p];
}
if (p==0) fail[np]=1;
else{
int q=g[p][c];
if (step[q]==step[p]+1) fail[np]=q;
else{
int nq=newnode();
fo(i,0,25) g[nq][i]=g[q][i];
fail[nq]=fail[q];
step[nq]=step[p]+1;
fail[q]=nq;
fail[np]=nq;
while (p&&g[p][c]==q){
g[p][c]=nq;
p=fail[p];
}
}
}
last=np;
}
int main(){
scanf("%d",&ca);
while (ca--){
scanf("%s",s1+1);
scanf("%s",s2+1);
n=strlen(s1+1);
reverse(s2+1,s2+n+1);
fd(i,n,1) s1[i*2]=s1[i],s1[i*2-1]='#';
s1[n*2+1]='#';
s1[n*2+2]='%';
s1[0]='$';
fa[1]=0;
j=1;
ans=-1;
fo(i,2,n*2){
fa[i]=0;
if (i<=j+fa[j]) fa[i]=min(fa[j*2-i],j+fa[j]-i);
while (s1[i+fa[i]+1]==s1[i-fa[i]-1]) fa[i]++;
if (i+fa[i]>=j+fa[j]) j=i;
//ans=max(ans,fa[i]);
}
fo(i,1,n) c[i]=0;
fo(i,2,n*2){
l=i-fa[i];
if (l%2==1) l++;
l/=2;
c[l]=max(c[l],fa[i]);
}
k=0;
fo(i,1,n){
if (k) k-=2;
k=max(k,c[i]);
sa[i]=k;
}
fd(i,n,1) s2[i*2]=s2[i],s2[i*2-1]='#';
s2[n*2+1]='#';
s2[n*2+2]='%';
s2[0]='$';
fb[1]=0;
j=1;
fo(i,2,n*2){
fb[i]=0;
if (i<=j+fb[j]) fb[i]=min(fb[j*2-i],j+fb[j]-i);
while (s2[i+fb[i]+1]==s2[i-fb[i]-1]) fb[i]++;
if (i+fb[i]>=j+fb[j]) j=i;
//ans=max(ans,fb[i]);
}
fo(i,1,n) c[i]=0;
fo(i,2,n*2){
l=i-fb[i];
if (l%2==1) l++;
l/=2;
c[l]=max(c[l],fb[i]);
}
k=0;
fo(i,1,n){
if (k) k-=2;
k=max(k,c[i]);
sb[i]=k;
}
tot=last=1;
fo(i,0,25) g[1][i]=0;
fo(i,1,n){
add(s1[i*2]-'a');
f[last]++;
if (i<n) mxa[last]=sa[i+1];
}
last=1;
fo(i,1,n){
add(s2[i*2]-'a');
f[last]+=2;
if (i<n) mxb[last]=sb[i+1];
}
fo(i,0,tot) sum[i]=0;
fo(i,1,tot) sum[step[i]]++;
fo(i,1,tot) sum[i]+=sum[i-1];
fd(i,tot,1) a[sum[step[i]]--]=i;
fd(i,tot,2){
j=a[i];
f[fail[j]]|=f[j];
mxa[fail[j]]=max(mxa[fail[j]],mxa[j]);
mxb[fail[j]]=max(mxb[fail[j]],mxb[j]);
}
fo(i,2,tot)
if (f[i]==3) ans=max(ans,step[i]*2+max(mxa[i],mxb[i]));
printf("%d\n",ans);
mx=max(mx,tot);
}
}