题目意思是求给定一序列中所有不同连续子序列的个数。
对于一个长度为n的序列那么其总共的连续子序列的个数是(n+1)*n/2;所以只要用总共的减去重复的就可以了。如何计算重复的?还是要用到height数组,由于字典序排好的后缀,所以如果当前相邻字符串的公共前缀长度为k的话,那么只要用总数减取k就行了,比如一个串为abcabc在字典序中,adc与adcadc是相邻的,所以就说明了有长度为3的共同部分,所以要减去3.
我在这里想了一会,既然公共的前缀为abc为什么只减去3,明明abc代表了,6个子串,为什么不减去(k+1)*k/2。最后想明白了,这个3减去的是以a开头的子串,而相bc这样的会在 bc与bcabc这组的height中被处理掉,所以这样做刚刚好除去了所有的,且不重复。
#include <iostream>
#include <stdio.h>
#include <cstring>
#include <algorithm>
using namespace std;
#define Max_N (2000 + 100)
int n;
int k;
int a[Max_N];
int rank1[Max_N];
int tmp[Max_N];
bool compare_sa(int i, int j)
{
if(rank1[i] != rank1[j]) return rank1[i] < rank1[j];
else {
int ri = i + k <= n ? rank1[i + k] : -1;
int rj = j + k <= n ? rank1[j + k] : -1;
return ri < rj;
}
}
void construct_sa(int buf[], int s, int sa[])
{
int len = s;
for (int i = 0; i <= len; i++) {
sa[i] = i;
rank1[i] = i < len ? buf[i] : -1;
}
for ( k = 1; k <= len; k *= 2) {
sort(sa, sa + len +1, compare_sa);
tmp[sa[0]] = 0;
for (int i = 1; i <= len; i++) {
tmp[sa[i]] = tmp[sa[i-1]] + (compare_sa(sa[i-1], sa[i]) ? 1 : 0);
}
for (int i = 0; i <= len; i++) {
rank1[i] = tmp[i];
}
}
}
void construct_lcp(int buf[], int len, int *sa, int *lcp)
{
int h = 0;
lcp[0] = 0;
for (int i = 0; i < len; i++) {
int j = sa[rank1[i] - 1];
if (h > 0) h--;
for (; j + h < len && i + h < len; h++) {
if (buf[j+h] != buf[i+h]) break;
}
lcp[rank1[i] - 1] = h;
}
}
int sa[Max_N];
int rev[Max_N];
int lcp[Max_N];
char buf1[Max_N];
int main()
{
int T;
cin >> T;
while (T--) {
scanf ("%s", buf1);
n = strlen(buf1);
for (int i = 0; i < n; i++)
a[i] = int(buf1[i]);
construct_sa(a, n, sa);
construct_lcp(a, n, sa, lcp);
int ans = (n + 1) * n / 2;
for (int i = 0; i < n; i++)
ans -= lcp[i];
printf("%d\n", ans);
}
return 0;
}