可以发现,对于原串的每个长度>1的子串而言,将其除了最后一个字符之外反向接在其结尾,都是一个合法解。该解的长度一定是奇数。
对于原串的每个长度>2,且结尾两个字符相同的子串而言,将其除了最后两个字符之外反向接在其结尾,都是一个合法解。该解的长度一定是偶数。
于是在SAM上统计一下就可以了……非常容易,O(n)。
别忘了减去长度为1的子串,以及长度为2,且两个字符相等的子串数。
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
#define MAXL 100000
#define MAXC 26
int v[2*MAXL+10],__next[2*MAXL+10],first[2*MAXL+10],e;
void AddEdge(int U,int V){
v[++e]=V;
__next[e]=first[U];
first[U]=e;
}
char s[MAXL+10];//文本串
int len/*文本串长度*/;
struct SAM{
int one_of_endpos[2*MAXL+10];
int n/*状态数0~n-1*/,maxlen[2*MAXL+10],minlen[2*MAXL+10],trans[2*MAXL+10][MAXC],slink[2*MAXL+10];
int new_state(int _maxlen,int _minlen,int _trans[],int _slink){
maxlen[n]=_maxlen;
minlen[n]=_minlen;
for(int i=0;i<MAXC;++i){
if(_trans==NULL){
trans[n][i]=-1;
}
else{
trans[n][i]=_trans[i];
}
}
slink[n]=_slink;
return n++;
}
int add_char(char ch,int u,int pos){
if(u==-1){
return new_state(0,0,NULL,-1);
}
int c=ch-'a';
int z=new_state(maxlen[u]+1,-1,NULL,-1);
one_of_endpos[z]=pos;
int v=u;
while(v!=-1 && trans[v][c]==-1){
trans[v][c]=z;
v=slink[v];
}
if(v==-1){//最简单的情况,suffix-path(u->S)上都没有对应字符ch的转移
minlen[z]=1;
slink[z]=0;
return z;
}
int x=trans[v][c];
if(maxlen[v]+1==maxlen[x]){//较简单的情况,不用拆分x
minlen[z]=maxlen[x]+1;
slink[z]=x;
return z;
}
int y=new_state(maxlen[v]+1,-1,trans[x],slink[x]);//最复杂的情况,拆分x
slink[y]=slink[x];
minlen[x]=maxlen[y]+1;
slink[x]=y;
minlen[z]=maxlen[y]+1;
slink[z]=y;
int w=v;
while(w!=-1 && trans[w][c]==x){
trans[w][c]=y;
w=slink[w];
}
minlen[y]=maxlen[slink[y]]+1;
return z;
}
void dfs(int U){
for(int i=first[U];i;i=__next[i]){
dfs(v[i]);
one_of_endpos[U]=one_of_endpos[v[i]];
}
}
void work_slink_tree(){
for(int i=1;i<n;++i){
AddEdge(slink[i],i);
}
dfs(0);
}
}sam;
ll ans;
bool vis[1001];
int main(){
// freopen("uestc.h.in","r",stdin);
scanf("%s",s);
len=strlen(s);
int U=sam.add_char(0,-1,0);
for(int i=0;i<len;++i){
U=sam.add_char(s[i],U,i);
}
sam.work_slink_tree();
for(int i=0;i<len;++i){
if(!vis[s[i]]){
vis[s[i]]=1;
--ans;
}
}
memset(vis,0,sizeof(vis));
for(int i=0;i<len-1;++i){
if(s[i]==s[i+1] && (!vis[s[i]])){
vis[s[i]]=1;
--ans;
}
}
for(int i=1;i<sam.n;++i){
ans+=(ll)(sam.maxlen[i]-sam.minlen[i]+1);
if(s[sam.one_of_endpos[i]]==s[sam.one_of_endpos[i]-1]){
ans+=(ll)(sam.maxlen[i]-max(2,sam.minlen[i])+1);
}
}
printf("%lld\n",ans);
return 0;
}