题目链接:点我
题意:求出长度不小于k的公共子串个数
思路:本来想套个二分查询的,但是发现真的难搞,转移也不知道错哪了,用单调栈写,发现转移也不好写,于是基本都是借鉴的…
下面是代码
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <vector>
#include <string>
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#define Fi first
#define Se second
#define ll long long
#define inf 0x3f3f3f3f
#define lowbit(x) (x&-x)
#define in(b) scanf("%d",&b)
#define mmin(a,b,c) min(a,min(b,c))
#define mmax(a,b,c) max(a,max(b,c))
#define debug(a) cout<<#a<<"="<<a<<endl;
#define debug2(a,b) cout<<#a<<"="<<a<<" "<<#b<<"="<<b<<endl;
#define debug3(a,b,c) cout<<#a<<"="<<a<<" "<<#b<<"="<<b<<" "<<#c<<"="<<c<<endl;
#define show_time cout << "The run time is:" << (double)clock() /CLOCKS_PER_SEC<< "s" << endl;
using namespace std;
const int N=3e5+10;
char s[N];
int x[N],y[N],sa[N],height[N],rnk[N],c[N];
void get_c(int n,int m)
{
for(int i=0;i<=m;i++) c[i]=0;
for(int i=1;i<=n;i++) c[x[i]]++;
for(int i=1;i<=m;i++) c[i]+=c[i-1];
}
void get_sa(int n,int m)
{
for(int i=1;i<=n;i++) x[i]=s[i];
get_c(n,m);
for(int i=n;i>=1;i--) sa[c[x[i]]--]=i;
for(int k=1;k<=n;k<<=1)
{
int num=0;
for(int i=n-k+1;i<=n;i++) y[++num]=i;
for(int i=1;i<=n;i++) if(sa[i]>k) y[++num]=sa[i]-k;
get_c(n,m);
for(int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);num=0;
for(int i=1;i<=n;i++) x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])? num : ++num;
if(num==n) break;
m=num;
}
}
void get_height(int n)
{
for(int i=1;i<=n;i++) rnk[sa[i]]=i;
int k=0;
for(int i=1;i<=n;i++)
{
if(rnk[i]==1) continue;
if(k) k--;
int j=sa[rnk[i]-1];
while(j+k<=n&&i+k<=n&&s[i+k]==s[j+k]) k++;
height[rnk[i]]=k;;
}
}
pair<int,int> num[N];
int main()
{
int k;
while(~in(k)&&k)
{
scanf("%s",s+1);
int len1=strlen(s+1);
s[len1+1]='$';
scanf("%s",s+2+len1);
int n=strlen(s+1);
get_sa(n,255);
get_height(n);
int top=0;
ll s=0,ans=0;
for(int i=2;i<=n;i++)
{
if(height[i]<k) {top=s=0;continue;}
int cnt=0;
while(top>0&&num[top].Fi>=height[i])
{
cnt+=num[top].Se;
s-=num[top].Se*(num[top].Fi-height[i]);
top--;
}
if(sa[i-1]<=len1) cnt++,s+=height[i]-k+1;
if(sa[i]>len1+1)
{
// debug2(s,1)
ans+=s;
}
num[++top]={height[i],cnt};
}
top=s=0;
for(int i=2;i<=n;i++)
{
if(height[i]<k) {top=s=0;continue;}
int cnt=0;
while(top>0&&num[top].Fi>=height[i])
{
cnt+=num[top].Se;
s-=num[top].Se*(num[top].Fi-height[i]);
top--;
}
if(sa[i-1]>len1+1) cnt++,s+=height[i]-k+1;
if(sa[i]<=len1)
{
// debug2(s,2)
ans+=s;
}
num[++top]={height[i],cnt};
}
printf("%lld\n",ans);
}
return 0;
}
/*
3 1 2 1
a$a
$a 0
a 0
a$a 1
*/