题意:给定一个字符c和一个字符串s,问s中有多少个不同的子串包含c。
分析:比赛的时候对后缀数组(sa)有些生疏了,想起学后缀自动机(sam)时记得sam能直接处理出所有不同的子串,于是就在卿神的博客上找了个sam的模板直接做了,因为对sam不熟,数组空间开小了,还wa了一发。赛后补题的时候想起sa也能处理,为了回顾一下sa也写了一遍。思路:sa将所有后缀从小到大排好序之后,可以得到排名相邻的两个后缀的最长公共前缀长度(height);对每一个后缀来说,长度超过这个公共长度的前缀就是新出现的不同子串。然后我们再确定该后缀从哪个位置开始才包含字符c(即起点之后第一个c出现的位置),子串的右端点必须同时不小于这两个位置,按此统计好就行了。(好像因为*26的原因,O(n)建的sam反而比O(nlogn)建的sa慢好多)。
代码:
#include<map>//sa
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<vector>
#include<string>
#include<stdio.h>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
// Capacity of every fixed-size buffer declared below.
const int N=100010;
// Template constants; none of mod, MOD1, MOD2 or EPS is referenced by the
// suffix-array code visible in this listing.
const int mod=100000000;
const int MOD1=1000000007;
const int MOD2=1000000009;
const double EPS=0.00000001;
// 64-bit integer type used for the answer accumulator in main.
typedef long long ll;
// More template constants/aliases (MOD, MAX, INF, pi, db, ull); not referenced
// by the suffix-array code visible in this listing.
const ll MOD=1000000007;
const int MAX=1000000010;
const ll INF=1ll<<55;
const double pi=acos(-1.0);
typedef double db;
typedef unsigned long long ull;
// ch: token holding the query character (main only looks at ch[0]).
// s : the input string read by main.
char ch[5],s[N];
// w[i]  : filled by main; index of the first occurrence of ch[0] at or after
//         position i, or the string length when there is none.
// sa[r] : start position of the suffix with rank r (filled by build_sa).
// ran[i]: rank of the suffix starting at i (filled by getHeight).
// hei[r]: common-prefix length of the suffixes ranked r-1 and r (getHeight).
int w[N],sa[N],ran[N],hei[N];
// cu : counting-sort buckets used by build_sa.
// t1/t2: scratch buffers that build_sa uses for rank keys / sort order.
// str: integer-coded text; main stores s[i]-'a'+1 and a terminating 0.
int cu[N],t1[N],t2[N],str[N];
void build_sa(int n,int m) {
int i,j,p,*x=t1,*y=t2;
for (i=0;i<m;i++) cu[i]=0;
for (i=0;i<n;i++) cu[x[i]=str[i]]++;
for (i=1;i<m;i++) cu[i]+=cu[i-1];
for (i=n-1;i>=0;i--) sa[--cu[x[i]]]=i;
for (p=0,j=1;j<=n;m=p,p=0,j<<=1) {
for (i=n-j;i<n;i++) y[p++]=i;
for (i=0;i<n;i++) if (sa[i]>=j) y[p++]=sa[i]-j;
for (i=0;i<m;i++) cu[i]=0;
for (i=0;i<n;i++) cu[x[y[i]]]++;
for (i=1;i<m;i++) cu[i]+=cu[i-1];
for (i=n-1;i>=0;i--) sa[--cu[x[y[i]]]]=y[i];
swap(x,y);
p=1;x[sa[0]]=0;
for (i=1;i<n;i++)
if (y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+j]==y[sa[i]+j]) x[sa[i]]=p-1;
else x[sa[i]]=p++;
if (p>=n) break ;
}
}
// Fills ran[] and hei[] from an already built suffix array.
//   n : length of the text without the terminating symbol; sa[0..n] must be
//       valid (the caller built it over n + 1 symbols).
// ran[i] becomes the rank of the suffix starting at i, and hei[r] the number
// of leading symbols shared by the suffixes ranked r-1 and r.
// The match length is carried over from one start position to the next and
// reduced by at most one per step.
void getHeight(int n) {
    for (int r=0;r<=n;++r) ran[sa[r]]=r;

    int matched=0;
    for (int pos=0;pos<n;++pos) {
        if (matched>0) --matched;
        // Start of the suffix ranked immediately before the one at `pos`.
        const int prev=sa[ran[pos]-1];
        while (str[pos+matched]==str[prev+matched]) ++matched;
        hei[ran[pos]]=matched;
    }
}
// Suffix-array solution driver.
// Reads the number of test cases, then for each case a character token and a
// string, and prints "Case #k: X" where X is the number of distinct
// substrings of the string that contain the character.
// Stops quietly when the input ends early.
int main()
{
    int t;
    if (scanf("%d", &t)!=1) return 0;
    for (int ca=1;ca<=t;ca++) {
        // Field widths keep the reads inside ch[5] and s[N] (N == 100010);
        // keep them in sync with the array sizes.
        if (scanf("%4s%100009s", ch, s)!=2) break;
        const int len=(int)strlen(s);

        // Symbols are coded 1..26 so that 0 can serve as the unique smallest
        // terminator. NOTE(review): assumes s holds only 'a'..'z' - confirm
        // against the problem statement.
        for (int i=0;i<len;i++) str[i]=s[i]-'a'+1;
        str[len]=0;
        build_sa(len+1,30);
        getHeight(len);

        // w[i]: first index >= i holding the query character, len if none.
        w[len]=len;
        for (int i=len-1;i>=0;i--) w[i]=(s[i]==ch[0])?i:w[i+1];

        // For the suffix ranked i, a substring ending at index e is new when
        // e >= sa[i]+hei[i] and contains the character when e >= w[sa[i]].
        // Rank 0 is the terminator-only suffix and is skipped.
        ll ans=0;
        for (int i=1;i<=len;i++)
            ans+=(ll)(len-max(sa[i]+hei[i],w[sa[i]]));

        // Stream output instead of the MSVC-only "%I64d" conversion.
        cout<<"Case #"<<ca<<": "<<ans<<'\n';
    }
    return 0;
}
代码:
#include<map>//sam
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<vector>
#include<string>
#include<stdio.h>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
// Capacity of the input buffer; the automaton tables below hold N*2 states.
const int N=200010;
// Template constants; none of mod, MOD1, MOD2 or EPS is referenced by the
// suffix-automaton code visible in this listing.
const int mod=100000000;
const int MOD1=1000000007;
const int MOD2=1000000009;
const double EPS=0.00000001;
// 64-bit integer type used for the answer accumulator in main.
typedef long long ll;
// More template constants/aliases (MOD, MAX, INF, pi, db, ull); not referenced
// by the suffix-automaton code visible in this listing.
const ll MOD=1000000007;
const int MAX=2000000010;
const ll INF=1ll<<55;
const double pi=acos(-1.0);
typedef double db;
typedef unsigned long long ull;
// c: token holding the query character (main only looks at c[0]).
// s: the input string read by main.
char c[5],s[N];
// d : first the in-degree of each state, later the number of paths from the
//     start state to it (both computed in main).
// fa: letter index of a transition entering the state (main sets fa[1]=30).
// p : states in the order main's queue-based traversal visits them.
// dp: number of start-to-state paths counted as containing the query letter.
int d[N*2],fa[N*2],p[N*2],dp[N*2];
// la : state reached after the most recently added character.
// Now: index of the newest state (main starts both la and Now at 1).
// pre: link followed by SAM::add when walking back from a state.
// len: length value SAM::add assigns to each state.
// Son: transition table, Son[state][letter]; 0 means "no transition".
int la,Now,pre[N*2],len[N*2],Son[N*2][26];
// Suffix automaton stored in the file-level tables Son / pre / len, with
// state 1 as the start state and 0 meaning "no state". The caller must reset
// la, Now and the tables before building (main does this per test case).
struct SAM
{
    // Extends the automaton by one letter.
    //   x : letter index used as the second subscript of Son (main passes
    //       s[i]-'a').
    void add(int x)
    {
        int cur=la;
        const int fresh=++Now;
        la=fresh;
        len[fresh]=len[cur]+1;

        // Walk the links, adding the missing transition on x.
        while (cur&&!Son[cur][x]) {
            Son[cur][x]=fresh;
            cur=pre[cur];
        }

        // Ran off the chain: the new state links to the start state.
        if (cur==0) {
            pre[fresh]=1;
            return;
        }

        const int q=Son[cur][x];
        if (len[q]==len[cur]+1) {
            pre[fresh]=q;
            return;
        }

        // q is too long to be the link target: clone it with the shorter
        // length and redirect the transitions that pointed at q.
        const int clone=++Now;
        memcpy(Son[clone],Son[q],sizeof Son[clone]);
        len[clone]=len[cur]+1;
        pre[clone]=pre[q];
        pre[q]=clone;
        pre[fresh]=clone;
        while (cur&&Son[cur][x]==q) {
            Son[cur][x]=clone;
            cur=pre[cur];
        }
    }
}sam;
int main()
{
int i,j,t,ca,Len,l,r;
ll ans;
scanf("%d", &t);
for (ca=1;ca<=t;ca++) {
scanf("%s%s", c, s);
la=Now=1;
Len=strlen(s);
memset(pre,0,sizeof(pre));
memset(len,0,sizeof(len));
memset(Son,0,sizeof(Son));
for (i=0;i<Len;i++) sam.add(s[i]-'a');
memset(d,0,sizeof(d));
memset(fa,0,sizeof(fa));
for (i=1;i<=Now;i++)
for (j=0;j<26;j++)
if (Son[i][j]) {
fa[Son[i][j]]=j,d[Son[i][j]]++;
}
l=1;r=1;p[1]=1;
for (;l<=r;l++)
for (i=0;i<26;i++)
if (Son[p[l]][i]) {
d[Son[p[l]][i]]--;
if (d[Son[p[l]][i]]==0) p[++r]=Son[p[l]][i];
}
fa[1]=30;ans=0;
memset(d,0,sizeof(d));
memset(dp,0,sizeof(dp));
for (d[1]=1,i=1;i<=r;i++)
for (j=0;j<26;j++)
if (Son[p[i]][j]) d[Son[p[i]][j]]+=d[p[i]];
for (i=1;i<=r;i++) {
if (fa[p[i]]==c[0]-'a') dp[p[i]]=d[p[i]];
for (j=0;j<26;j++)
if (Son[p[i]][j]) dp[Son[p[i]][j]]+=dp[p[i]];
}
for (i=2;i<=r;i++) ans+=(ll)dp[p[i]];
printf("Case #%d: %I64d\n", ca, ans);
}
return 0;
}