http://acm.hdu.edu.cn/showproblem.php?pid=5769
根据http://blog.youkuaiyun.com/viphong/article/details/52098859可以容易求到一个字符串里不同的子串个数,
而本题要求的是 包含X的字符串,
那么原来的公式是∑1lengthlength−(sa[i]+height[i])
就可以变成∑1lengthlength−max(nxt[sa[i]],sa[i]+height[i])
意思就是对于一个后缀,其有效的前缀原本为 n-(sa[i]+height[i]),但是考虑要必须包含字符x,则找到离sa[i]
最近的一个含X的位置, 有效前缀至多不超过n-nex[sa[i]]
两者取min
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<vector>
#include<algorithm>
using namespace std;
const int N = 100000+50;
int cmp(int *r,int a,int b,int l)
{
return (r[a]==r[b]) && (r[a+l]==r[b+l]);
}
int wa[N],wb[N],ws[N],wv[N];
int Rank[N],height[N];
void DA(int *r,int *sa,int n,int m) //此处N比输入的N要多1,为人工添加的一个字符,用于避免CMP时越界
{
int i,j,p,*x=wa,*y=wb,*t;
for(i=0; i<m; i++) ws[i]=0;
for(i=0; i<n; i++) ws[x[i]=r[i]]++;
for(i=1; i<m; i++) ws[i]+=ws[i-1];
for(i=n-1; i>=0; i--) sa[--ws[x[i]]]=i; //预处理长度为1
for(j=1,p=1; p<n; j*=2,m=p) //通过已经求出的长度J的SA,来求2*J的SA
{
for(p=0,i=n-j; i<n; i++) y[p++]=i; // 特殊处理没有第二关键字的
for(i=0; i<n; i++) if(sa[i]>=j) y[p++]=sa[i]-j; //利用长度J的,按第二关键字排序
for(i=0; i<n; i++) wv[i]=x[y[i]];
for(i=0; i<m; i++) ws[i]=0;
for(i=0; i<n; i++) ws[wv[i]]++;
for(i=1; i<m; i++) ws[i]+=ws[i-1];
for(i=n-1; i>=0; i--) sa[--ws[wv[i]]]=y[i]; //基数排序部分
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1; i<n; i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++; //更新名次数组x[],注意判定相同的
}
}
void calheight(int *r,int *sa,int n) // 此处N为实际长度
{
int i,j,k=0; // height[]的合法范围为 1-N, 其中0是结尾加入的字符
for(i=1; i<=n; i++) Rank[sa[i]]=i; // 根据SA求RANK
for(i=0; i<n; height[Rank[i++]] = k ) // 定义:h[i] = height[ Rank[i] ]
for(k?k--:0,j=sa[Rank[i]-1]; r[i+k]==r[j+k]; k++); //根据 h[i] >= h[i-1]-1 来优化计算height过程
}
int n;
char ss[N];
char mode[10];
int aa[N];
int sa[N];
int pos[N];
int nex[N];
long long solve()
{
DA(aa,sa,n+1,30);
calheight(aa,sa,n);
int p=-1;
for (int i=n-1; i>=0; i--)
{
if (ss[i]==mode[0])
p=i;
if (p>=0)
pos[Rank[i]]=p;
else pos[Rank[i]]=-1;
}
long long ans=0;
for (int i=1; i<=n; i++)
if (pos[i]!=-1)
ans+=n-max(pos[i],sa[i]+height[i]);
return ans;
}
int main ()
{
int t;
scanf("%d",&t);
int cnt=1;
while(t--)
{
scanf("%s%s",mode,ss);
n=strlen(ss);
for (int i=0; i<n; i++)
aa[i]=ss[i]-'a'+1;
aa[n]=0;
long long ans=solve();
printf("Case #%d: %lld\n",cnt++,ans);
}
return 0;
}