题意:
给定两个串S和T,从S中找两个子串组成T(拼起来和T一模一样),两个子串可重叠,问有几种组合方法?
三种方法:
KMP+计数
直接用T去匹配S,匹配到T中的T[j]时,说明长度为j的前缀出现了,更新sum[j],然后通过递归计算sum[next[i]] += sum[i]。为什么要这样计算呢?因为当某个前缀是另一个前缀的后缀时,会出现少计算的情况,更新的时候只更新后面那个前缀的出现次数而不更新前面的前缀。比如aabaa,aa和aabaa两个前缀就会出现少计算的情况(就是计算到aabbaa的时候会跳过最后一段aa),最后需要cnt[2] += cnt[5]。
然后将S,T都反过来再算一次
最后用for (i = 1; i < len; ++i) ans += sum1[i] * sum2[len-i]
// whn6325689
// Mr.Phoebe
// http://blog.youkuaiyun.com/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);
typedef long long ll;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;
#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))
#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n
template<class T>
inline bool read(T &n)
{
T x = 0, tmp = 1;
char c = getchar();
while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
if(c == EOF) return false;
if(c == '-') c = getchar(), tmp = -1;
while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
n = x*tmp;
return true;
}
template <class T>
inline void write(T n)
{
if(n < 0)
{
putchar('-');
n = -n;
}
int len = 0,data[20];
while(n)
{
data[len++] = n%10;
n /= 10;
}
if(!len) data[len++] = 0;
while(len--) putchar(data[len]+48);
}
//-----------------------------------
const int MAXN=100010;
ll ans,pre[MAXN],las[MAXN];
int next[MAXN];
char s[MAXN],t[MAXN];
void kmp_pre(char x[],int m,int next[])
{
next[0]=next[1]=0;
int i,j;
for(i=1;i<m;i++)
{
j=next[i];
while(j&&x[i]!=x[j]) j=next[j];
next[i+1]=(x[i]==x[j])? j+1 : 0;
}
}
void KMP_Count(char x[],int n,char y[],int m,ll sum[])
{
kmp_pre(y,m,next);
int i,j=0;
for(i=0;i<n;i++)
{
while(j && x[i]!=y[j]) j=next[j];
if(x[i]==y[j])
{
j++;
sum[j]++;
// printf("j=%d,sum=%lld\n",j,sum[j]);
}
}
for(i=m;i>=0;i--)
if(next[i])
sum[next[i]]+=sum[i];
}
int main()
{
// freopen("data.txt","r",stdin);
int T;
read(T);
while(T--)
{
CLR(pre,0);CLR(las,0);ans=0;
scanf("%s%s",s,t);
int n=strlen(s);
int m=strlen(t);
KMP_Count(s,n,t,m,pre);
reverse(s,s+n);reverse(t,t+m);
KMP_Count(s,n,t,m,las);
for(int i=0;i<m;i++)
ans+=pre[i]*las[m-i];
write(ans),putchar('\n');
}
return 0;
}
哈希+二分
枚举每个S中的位置,用哈希+二分的方式看看最多能够匹配多少长度的串,然后对于T,这一个区间内的每一个前缀都是可行的
反过来再做一次
// whn6325689
// Mr.Phoebe
// http://blog.youkuaiyun.com/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;
#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))
#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n
template<class T>
inline bool read(T &n)
{
T x = 0, tmp = 1;
char c = getchar();
while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
if(c == EOF) return false;
if(c == '-') c = getchar(), tmp = -1;
while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
n = x*tmp;
return true;
}
template <class T>
inline void write(T n)
{
if(n < 0)
{
putchar('-');
n = -n;
}
int len = 0,data[20];
while(n)
{
data[len++] = n%10;
n /= 10;
}
if(!len) data[len++] = 0;
while(len--) putchar(data[len]+48);
}
//-----------------------------------
const int MAXN=100010;
ll c[2][MAXN];
ull hl[MAXN],hr[MAXN],x[MAXN];
ull tl[MAXN],tr[MAXN];
char s[MAXN],t[MAXN];
void update(int d,int i,int v)
{
for(;i<MAXN;i+=lowbit(i))
c[d][i]+=v;
}
ll getsum(int d,int i)
{
ll ans=0;
for(;i;i-=lowbit(i))
ans+=c[d][i];
return ans;
}
int main()
{
int T;
read(T);
x[0]=1;
for(int i=1;i<MAXN;i++)
x[i]=x[i-1]*123;
while(T--)
{
scanf("%s%s",s,t);
int n=strlen(s),m=strlen(t);
hl[n]=0;
for(int i=n-1;i>=0;i--)
hl[i]=hl[i+1]*123+s[i];
hr[0]=0;
for(int i=1;i<=n;i++)
hr[i]=hr[i-1]*123+s[i-1];
tl[m]=0;
for(int i=m-1;i>=0;i--)
tl[i]=tl[i+1]*123+t[i];
tr[0]=0;
for(int i=1;i<=m;i++)
tr[i]=tr[i-1]*123+t[i-1];
CLR(c,0);
for(int i=0;i<n;i++)
{
if(s[i]!=t[0]) continue;
int l=0,r=min(n-i,m)+1;
while(l+1<r)
{
int mid=MID(l,r);
ull A=hl[i]-hl[i+mid]*x[mid];
ull B=tl[0]-tl[mid]*x[mid];
if(A==B) l=mid;
else r=mid;
}
update(0,1,1);
update(0,l+1,-1);
}
for(int i=n-1;i>=0;i--)
{
if(s[i]!=t[m-1]) continue;
int l=0,r=min(i+1,m)+1;
while(l+1<r)
{
int mid=MID(l,r);
ull A=hr[i+1]-hr[i+1-mid]*x[mid];
ull B=tr[m]-tr[m-mid]*x[mid];
if(A==B) l=mid;
else r=mid;
}
update(1,1,1);
update(1,l+1,-1);
}
ll ans=0;
for(int i=1;i<m;i++)
ans+=getsum(0,i)*getsum(1,m-i);
write(ans),putchar('\n');
}
return 0;
}
扩展KMP
总体思路差不多
// whn6325689
// Mr.Phoebe
// http://blog.youkuaiyun.com/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);
typedef long long ll;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;
#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))
#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n
template<class T>
inline bool read(T &n)
{
T x = 0, tmp = 1;
char c = getchar();
while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
if(c == EOF) return false;
if(c == '-') c = getchar(), tmp = -1;
while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
n = x*tmp;
return true;
}
template <class T>
inline void write(T n)
{
if(n < 0)
{
putchar('-');
n = -n;
}
int len = 0,data[20];
while(n)
{
data[len++] = n%10;
n /= 10;
}
if(!len) data[len++] = 0;
while(len--) putchar(data[len]+48);
}
//-----------------------------------
const int MAXN=100010;
void ExtenKMP(char *a,char *b,int M,int N,int *next,int *ret)
{
int i,j,k;
for(j=0;j+1<M&&a[j]==a[j+1];j++);
next[1]=j;
k=1;
for(int i=2;i<M;i++)
{
int Len=k+next[k],L=next[i-k];
if(L<Len-i)
next[i]=L;
else
{
for(j=max(0,Len-i);i+j<M&&a[j]==a[j+i];j++);
next[i]=j;
k=i;
}
}
for(j=0;j<N&&j<M&&a[j]==b[j];j++);
ret[0]=j;
k=0;
for(i=1;i<N;i++)
{
int Len=k+ret[k],L=next[i-k];
if(L<Len-i)
ret[i]=L;
else
{
for(j=max(0,Len-i);j<M&&i+j<N&&a[j]==b[i+j];j++);
ret[i]=j;
k=i;
}
}
}
char s[MAXN],t[MAXN];
int ret1[MAXN],ret2[MAXN],next[MAXN];
int main()
{
int T;
read(T);
while(T--)
{
scanf("%s%s",s,t);
int n=strlen(s),m=strlen(t);
ExtenKMP(t,s,m,n,next,ret1);
sort(ret1,ret1+n);
reverse(s,s+n);reverse(t,t+m);
ExtenKMP(t,s,m,n,next,ret2);
sort(ret2,ret2+n);
ll ans=0;
for(int i=1;i<m;i++)
{
int t1=lower_bound(ret1,ret1+n,i)-ret1;
int t2=lower_bound(ret2,ret2+n,m-i)-ret2;
ans+=(n-t1)*1LL*(n-t2);
}
write(ans),putchar('\n');
}
return 0;
}