题意:给定 n 个字符串,每次询问给出 x, y, k,问第 x 个字符串到第 y 个字符串的所有子串中,一共有多少个与第 k 个字符串相等。
思路:参考题解。问题可以等价转换为求与第 k 个字符串相等的子串有多少个。根据后缀数组的定义,我们可以求出所有与 s[k] 相等的子串的数量。定义母串为所有字符串拼接起来的字符串(因为 sa[i] 是按照字典序排序的结果,我们只需以 s[k] 在母串中的位置(对应的排名)为中心 $P$,向两侧分别求出最远的 $L$ 和 $R$,使得 $\mathrm{LCP}(L, P) = \mathrm{len}(s[k])$ 且 $\mathrm{LCP}(P, R) = \mathrm{len}(s[k])$ 即可;这样 $L$ 到 $R$ 间的公共子串长度均大于等于 $\mathrm{len}(s[k])$,且公共子串的某个前缀都等于 s[k])。但是这样得到的是所有字符串中与 s[k] 相同的子串;如果我们想实现区间查询,就要借助线段树或者主席树,这里用的是线段树。
用线段树来维护区间查询:$T[x]$ 储存着一个区间(后缀排名 l~r,即 sa[l..r])内所有 sa[i] 所属的字符串的下标,并对 $T[x]$ 排序。这样在上面得到 $L$ 和 $R$ 后,在线段树上查询对应区间的节点,然后在每个节点上利用二分查找统计字符串下标在 x~y 之间的数量。
更多细节见于代码
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
// Shorthand for long long (not used by the code visible in this file).
#define ll long long
// Capacity of every per-position array below.
const int N = 500010;
// Segment-tree node storage: build() fills T[node] with the sorted owner
// indices (belong[sa[i]]) of the rank interval that node covers.
vector<int> T[N<<2];
// n: first the number of input strings, later overwritten in main() with the
//    length of the concatenated text.
// m: upper bound on symbol/rank values used by the counting sort in get_sa().
int n, m;
// len[i]: length of the i-th string; start[i]: 1-based position of its first
// character in the concatenated text; belong[p]: index of the string that owns
// text position p (0 for separator positions).
int len[N], start[N], belong[N];
// Concatenated text as integers: letters map to 1..26, separators to 27 and up.
int s[N];
// Input buffer for one string.
char buf[N];
// sa: suffix array; x, y: rank / second-key work arrays; c: counting-sort
// buckets; rk: rank of each text position; height: LCP of rank-adjacent
// suffixes; base[i]: floor(log2(i)); f: sparse table of minima over height.
int sa[N], x[N], y[N], c[N], rk[N], height[N], base[N], f[N][30];
// Builds the suffix array sa[1..n] of the integer text s[1..n] by prefix
// doubling, using a counting sort on (first key, second key) in each round.
//
// Reads the globals n (text length), m (upper bound on symbol values) and s.
// Expects c[] to be zero on entry (init() takes care of that).
// Writes sa, and uses x, y, c and m as scratch space.
void get_sa()
{
    // Round 0: sort suffixes by their first symbol.
    for (int i = 1; i <= n; i ++ ) c[x[i] = s[i]] ++ ;
    for (int i = 2; i <= m; i ++ ) c[i] += c[i - 1];
    for (int i = n; i; i -- ) sa[c[x[i]] -- ] = i;

    for (int k = 1; k <= n; k <<= 1)
    {
        // y[] lists the suffixes ordered by their second key (the rank of the
        // suffix k positions further on); suffixes with no second key go first.
        int num = 0;
        for (int i = n - k + 1; i <= n; i ++ ) y[ ++ num] = i;
        for (int i = 1; i <= n; i ++ )
            if (sa[i] > k)
                y[ ++ num] = sa[i] - k;

        // Stable counting sort by first key, walking y[] backwards.
        for (int i = 1; i <= m; i ++ ) c[i] = 0;
        for (int i = 1; i <= n; i ++ ) c[x[i]] ++ ;
        for (int i = 2; i <= m; i ++ ) c[i] += c[i - 1];
        for (int i = n; i; i -- ) sa[c[x[y[i]]] -- ] = y[i], y[i] = 0;

        // After the swap y[] holds the previous ranks and x[] receives the new ones.
        swap(x, y);
        x[sa[1]] = 1, num = 1;
        for (int i = 2; i <= n; i ++ )
        {
            const int cur = sa[i];
            const int before = sa[i - 1];
            // A suffix shorter than k has no second half: treat its second key
            // as 0 instead of reading y[] past position n, where the cell may
            // hold stale ranks from an earlier, longer input (or may lie
            // outside the array when n is close to the array capacity).
            const int curSecond = (cur + k <= n) ? y[cur + k] : 0;
            const int beforeSecond = (before + k <= n) ? y[before + k] : 0;
            x[cur] = (y[cur] == y[before] && curSecond == beforeSecond) ? num : ++ num;
        }
        if (num == n) break;   // all ranks distinct: the order is final
        m = num;
    }
}
void get_height()
{
for (int i = 1; i <= n; i ++ ) rk[sa[i]] = i;
for (int i = 1, k = 0; i <= n; i ++ )
{
if (rk[i] == 1) continue;
if (k) k -- ;
int j = sa[rk[i] - 1];
while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++ ;
height[rk[i]] = k;
}
}
// Prepares the range-minimum structure over height[1..n]:
//   base[i] = floor(log2(i)) for i in 1..n (base[0] = -1 seeds the recurrence),
//   f[i][j] = min(height[i .. i + 2^j - 1]) for every window that fits in 1..n.
void init_rmq()
{
    base[0] = -1;
    for (int i = 1; i <= n; i ++)
    {
        f[i][0] = height[i];
        base[i] = base[i >> 1] + 1;
    }
    // Build every level whose window length 2^j fits into n (including the
    // case 2^j == n), and only the cells whose whole window lies inside
    // 1..n, so no cell is derived from entries that were never computed.
    for (int j = 1; (1 << j) <= n; j ++)
    {
        const int half = 1 << (j - 1);
        for (int i = 1; i + (1 << j) - 1 <= n; i ++)
        {
            f[i][j] = min(f[i][j - 1], f[i + half][j - 1]);
        }
    }
}
// Returns the minimum of height[x..y] using the sparse table f[][] and the
// log table base[]. Callers in this file use it as the common-prefix length
// shared by the suffixes ranked x-1 through y. Requires x <= y.
int LCP(int x, int y)
{
    const int level = base[y - x + 1];
    const int leftMin = f[x][level];
    const int rightMin = f[y - (1 << level) + 1][level];
    return rightMin < leftMin ? rightMin : leftMin;
}
void init()
{
memset(c, 0, sizeof c);
memset(x, 0, sizeof x);
}
// Binary search for the smallest rank p in [l, r] such that
// LCP(p + 1, r) >= x, i.e. every height value in (p, r] is at least x.
// r is the anchor rank; when no smaller rank qualifies the result is r itself.
int get_L(int x, int l, int r)
{
    const int anchor = r;
    int lo = l;
    int hi = r;
    while (lo < hi)
    {
        const int mid = (lo + hi) >> 1;
        if (LCP(mid + 1, anchor) < x)
            lo = mid + 1;   // prefix too short here: answer lies to the right
        else
            hi = mid;
    }
    return lo;
}
// Binary search for the largest rank p in [l, r] such that
// LCP(l + 1, p) >= x, i.e. every height value in (l, p] is at least x.
// l is the anchor rank; when no larger rank qualifies the result is l itself.
int get_R(int x, int l, int r)
{
    const int anchor = l;
    int lo = l;
    int hi = r;
    while (lo < hi)
    {
        const int mid = (lo + hi + 1) >> 1;   // round up so that lo = mid makes progress
        if (LCP(anchor + 1, mid) < x)
            hi = mid - 1;   // prefix too short here: answer lies to the left
        else
            lo = mid;
    }
    return hi;
}
// Sentinel value that build() appends to every segment-tree node list;
// presumably chosen to exceed any string index stored in belong[].
int maxx = 1e9 + 10;
// Builds segment-tree node x covering ranks l..r: T[x] receives the owner
// string index of every suffix sa[l..r] plus the sentinel maxx, sorted
// ascending. Children 2x and 2x+1 are built recursively for the two halves.
void build(int x, int l, int r)
{
    vector<int>& owners = T[x];
    owners.clear();
    for (int rank = l; rank <= r; ++rank)
        owners.push_back(belong[sa[rank]]);
    owners.push_back(maxx);
    sort(owners.begin(), owners.end());

    if (l != r)
    {
        const int mid = (l + r) >> 1;
        build(x << 1, l, mid);
        build(x << 1 | 1, mid + 1, r);
    }
}
// Counts the suffixes with rank in [L, R] whose owner string index lies in
// [x, y]. o is the current node, covering ranks l..r. Nodes fully inside
// [L, R] are answered by binary search on their sorted list; other nodes
// delegate to the children that overlap the query range.
int query(int L, int R, int x, int y, int o, int l, int r)
{
    if (l < L || R < r)
    {
        // Partial overlap: descend into the relevant halves.
        const int mid = (l + r) >> 1;
        int total = 0;
        if (L <= mid)
            total += query(L, R, x, y, o << 1, l, mid);
        if (mid < R)
            total += query(L, R, x, y, o << 1 | 1, mid + 1, r);
        return total;
    }

    const vector<int>& owners = T[o];
    const int firstAtLeastX = static_cast<int>(lower_bound(owners.begin(), owners.end(), x) - owners.begin());
    const int firstAboveY = static_cast<int>(upper_bound(owners.begin(), owners.end(), y) - owners.begin());
    return firstAboveY - firstAtLeastX;
}
// Reads test cases until input is exhausted. Each test case gives the number
// of strings and the number of queries, then the strings, then the queries
// "x y k". For every query it prints how many substrings of strings x..y are
// equal to string k.
//
// The strings are concatenated into s[] (letters as 1..26) with a distinct
// separator value (27, 28, ...) after each one, so that no common prefix can
// run across a string boundary. The final separator is dropped from the text.
int main()
{
    int q;
    // Require both header numbers; a partial read would leave q undefined.
    while (scanf("%d%d", &n, &q) == 2)
    {
        int nextSep = 27;   // next unused separator symbol
        int cnt = 0;        // length of the concatenated text so far
        init();
        for (int i = 1; i <= n; i ++)
        {
            if (scanf("%s", buf) != 1) return 0;   // truncated input
            len[i] = strlen(buf);
            start[i] = cnt + 1;
            for (int j = 0; j < len[i]; j ++)
            {
                belong[++cnt] = i;
                s[cnt] = buf[j] - 'a' + 1;
            }
            s[++cnt] = nextSep ++;
            belong[cnt] = 0;   // separators belong to no string
        }
        s[cnt --] = 0;         // drop the trailing separator
        m = nextSep + 1;
        n = cnt;               // from here on n is the text length
        get_sa();
        get_height();
        init_rmq();
        build(1, 1, n);
        while (q --)
        {
            int x, y, k;
            if (scanf("%d%d%d", &x, &y, &k) != 3) return 0;   // truncated input
            // Widest rank interval around string k's own suffix in which every
            // suffix shares a prefix of at least len[k] symbols with it.
            const int anchor = rk[start[k]];
            const int L = get_L(len[k], 1, anchor);
            const int R = get_R(len[k], anchor, n);
            printf("%d\n", query(L, R, x, y, 1, 1, n));
        }
    }
    return 0;
}