后缀数组 或 后缀自动机
后缀数组的做法 O(nlogn) : 串接起来。要求的就是后缀数组中不属于同一串的后缀的LCP。暴力做O(n2)。发现我们一直都在区间取min,插一个线段树,每次找出最小的然后分两边做即可。(从大到小做,用并查集维护也很好)。这个方法复杂度不是很优秀,SA常数比较大,需要一个SA的优化技巧 if(m >= n) break; 即如果已经排完了就没必要继续了。不加可能会T。
后缀自动机做法 O(n) : 建广义SAM。一个节点代表的子串的贡献就是(len[i]-len[fail[i]]) * sz[i][0] * sz[i][1]。即节点代表的串数量 * 在串1中出现次数 * 在串2中出现次数。
以下后缀数组做法,自动机的还没写。
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 400005
using namespace std;
namespace runzhe2000
{
typedef long long ll;
ll ans;
int n1, n2, s[N], n, t1[N], t2[N], sum[N], height[N], sa[N], rank[N];
struct item{int l, r;}q[N];
char s1[N], s2[N];
struct seg{int mi, mipos, cnt1, cnt2;}t[N*5];
void build(int x, int l, int r)
{
if(l == r)
{
t[x].mi = height[l];
t[x].mipos = l;
if(sa[l] <= n1) t[x].cnt1 = 1, t[x].cnt2 = 0;
else if(n1+2 <= sa[l]) t[x].cnt1 = 0, t[x].cnt2 = 1;
else t[x].cnt1 = t[x].cnt2 = 0;
return;
}
int mid = (l+r)>>1;
build(x<<1,l,mid); build(x<<1|1,mid+1,r);
if(t[x<<1].mi < t[x<<1|1].mi) t[x].mi = t[x<<1].mi, t[x].mipos = t[x<<1].mipos;
else t[x].mi = t[x<<1|1].mi, t[x].mipos = t[x<<1|1].mipos;
t[x].cnt1 = t[x<<1].cnt1 + t[x<<1|1].cnt1;
t[x].cnt2 = t[x<<1].cnt2 + t[x<<1|1].cnt2;
}
int query_cnt(int x, int l, int r, int ql, int qr, int type)
{
if(l > r)return 0;
if(ql <= l && r <= qr) return type ? t[x].cnt1 : t[x].cnt2;
int mid = (l+r)>>1, ret = 0;
if(ql <= mid) ret += query_cnt(x<<1,l,mid,ql,qr,type);
if(mid < qr) ret += query_cnt(x<<1|1, mid+1,r,ql,qr,type);
return ret;
}
int query_mi(int x, int l, int r, int ql, int qr)
{
if(ql <= l && r <= qr) return t[x].mipos;
int mid = (l+r)>>1, p1 = 0, p2 = 0;
if(ql <= mid) p1 = query_mi(x<<1,l,mid,ql,qr);
if(mid < qr) p2 = query_mi(x<<1|1,mid+1,r,ql,qr);
if(!p1 || !p2) return p1?p1:p2;
else if(height[p1] < height[p2]) return p1;
else return p2;
}
void SA()
{
int *x = t1, *y = t2, m = 30;
for(int i = 1; i <= n; i++) sum[x[i] = s[i]]++;
for(int i = 1; i <= m; i++) sum[i] += sum[i-1];
for(int i = n; i >= 1; i--) sa[sum[x[i]]--] = i;
for(int k = 1; k <= n; k <<= 1)
{
int p = 0;
for(int i = n-k+1; i <= n; i++) y[++p] = i;
for(int i = 1; i <= n; i++) if(sa[i] - k > 0) y[++p] = sa[i] - k;
for(int i = 1; i <= m; i++) sum[i] = 0;
for(int i = 1; i <= n; i++) sum[x[i]]++;
for(int i = 1; i <= m; i++) sum[i] += sum[i-1];
for(int i = n; i >= 1; i--) sa[sum[x[y[i]]]--] = y[i];
swap(x, y);
for(int i = 1; i <= n; i++) x[sa[i]] = x[sa[i-1]] + ((y[sa[i]] == y[sa[i-1]] && y[sa[i]+k] == y[sa[i-1]+k]) ? 0 : 1);
m = x[sa[n]];
if(m >= n) break; // 有力的优化
}
for(int i = 1; i <= n; i++) rank[sa[i]]= i;
for(int k = 0, i = 1; i <= n; height[rank[i++]] = (k?k--:k))
for(; s[i+k] == s[sa[rank[i]-1]+k]; k++);
}
void main()
{
scanf("%s%s",s1+1,s2+1);
n1 = strlen(s1+1); n2 = strlen(s2+1);
for(int i = 1; i <= n1; i++) s[++n] = s1[i] - 'a' + 1;
s[++n] = 28;
for(int i = 1; i <= n2; i++) s[++n] = s2[i] - 'a' + 1;
SA(); build(1,1,n);
q[0] = (item){1,n};
for(int head = 0, tail = 1; head < tail; head++)
{
int l = q[head].l, r = q[head].r;
int minpos = query_mi(1,1,n,l+1,r);
if(l < minpos-1)q[tail++] = (item){l, minpos-1};
if(minpos < r)q[tail++] = (item){minpos, r};
ans += (ll) query_cnt(1,1,n,l,minpos-1, 1) * query_cnt(1,1,n,minpos, r, 0) * height[minpos];
ans += (ll) query_cnt(1,1,n,l,minpos-1, 0) * query_cnt(1,1,n,minpos, r, 1) * height[minpos];
}
printf("%lld\n",ans);
}
}
int main()
{
runzhe2000::main();
}