Educational Codeforces Round 70
E:
题意是给定
n
n
n 个字符串
s
i
s_i
si ,对于任意组合
s
i
+
s
j
s_i+s_j
si+sj ,求所有组合在
t
t
t 中的出现次数之和。
因为
s
i
s_i
si 和
s
j
s_j
sj 是相连的,因此可以在
t
t
t 上枚举连接点,然后求出以连接点右边第一个点为起始点有多少个
s
s
s,然后求出连接点左边第一个点为结束点有多少个
s
s
s,然后将这两个值相乘就是该连接点的贡献了。
因为将所有字符串翻转后,左边第一个点结束点有多少个
s
s
s 个问题转化为右边第一个点为起始点有多少个
s
s
s 的问题,因此我们只要求后一个问题即可。
可以将所有字符串按长度分类,长度小于
B
B
B 的字符串放到字符串中,长度大于
B
B
B 的字符串通过哈希或者kmp计算。将
B
B
B 设为
t
t
t 的根号长度后整体的复杂度就变成了
O
(
l
e
n
(
t
)
1.5
)
O(len(t)^{1.5})
O(len(t)1.5) 了。
除此之外,这个问题也是可以通过AC自动机来求的,复杂度就线性了。AC自动机在构建fail指针时需要将
c
[
n
o
w
]
+
=
c
[
f
a
i
l
[
n
o
w
]
]
c[now]+=c[fail[now]]
c[now]+=c[fail[now]] ,因为如果匹配到了
n
o
w
now
now ,说明也匹配到了
f
a
i
l
[
n
o
w
]
fail[now]
fail[now] 。然后直接将
t
t
t 在AC自动机上扫,对于
t
[
i
]
t[i]
t[i] ,对应的AC自动机上的状态
c
[
n
o
w
]
c[now]
c[now] 就是
s
s
s 中所有为
t
[
1..
i
]
t[1..i]
t[1..i] 的后缀的数量。这里就和根号做法的kmp很像,因为AC自动机就是多模kmp。
根号分块做法:
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5+7;
const int B = 300;
using ll = long long;
string t;
int c[N], revc[N], nxt[N];
vector<string> ls, ss;
struct Trie {
int nxt[N][26], tot, rt, cnt[N];
void init() {
tot = 0;
rt = newnode();
}
int newnode() {
++tot;
memset(nxt[tot], 0, sizeof(nxt[tot]));
cnt[tot] = 0;
return tot;
}
void insert(const string& s) {
int now = rt;
for(int i=0; i<s.length(); ++i) {
int ch = s[i]-'a';
if(nxt[now][ch]==0) nxt[now][ch]=newnode();
now = nxt[now][ch];
}
++cnt[now];
}
int query(const string& t, int start) {
int res = 0;
int now = rt;
for(int i=start; i<t.length(); ++i) {
int ch = t[i]-'a';
now = nxt[now][ch];
if(now==0) break;
res += cnt[now];
}
return res;
}
}trie;
void kmp_pre(const string &s) {
int i,j;
j=nxt[0]=-1;
i=0;
while(i<s.length()) {
while(-1!=j&&s[i]!=s[j]) j=nxt[j];
nxt[++i]=++j;
}
}
void kmp(const string &s, int *c) {
int i=0, j=0, ans=0;
kmp_pre(s);
while(i<t.length()) {
while(-1!=j&&t[i]!=s[j]) j=nxt[j];
++i; ++j;
if(j>=s.length()) {
++c[i-s.length()];
j=nxt[j];
}
}
}
void cal(int *c) {
// puts("cal");
trie.init();
for(string &s : ss) {
trie.insert(s);
}
for(int i=0; i<t.length(); ++i) {
c[i]=trie.query(t, i);
// printf("c[%d]=%d\n", i, c[i]);
}
for(string &s : ls) {
kmp(s, c);
}
// for(int i=0; i<t.length(); ++i) {
// printf("c[%d]=%d\n", i, c[i]);
// }
}
int main() {
ios::sync_with_stdio(false);
cin >> t;
int n;
cin >> n;
for(int i=0; i<n; ++i) {
string s;
cin >> s;
if(s.length()<B) ss.push_back(s);
else ls.push_back(s);
}
cal(c);
reverse(t.begin(), t.end());
for(string &s : ls) reverse(s.begin(), s.end());
for(string &s : ss) reverse(s.begin(), s.end());
cal(revc);
ll ans = 0;
for(int i=1; i<t.length(); ++i) {
ans += 1LL*c[i]*revc[t.length()-i];
// printf("%d %d\n", i, t.length()-i);
}
cout << ans<<endl;
}
AC自动机做法:
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5+7;
using ll = long long;
string t;
int c[N], revc[N], nxt[N];
vector<string> ls;
struct Trie {
int nxt[N][26], tot, rt, cnt[N], fail[N];
void init() {
tot = 0;
rt = newnode();
}
int newnode() {
++tot;
memset(nxt[tot], 0, sizeof(nxt[tot]));
cnt[tot] = 0;
return tot;
}
void insert(const string& s) {
int now = rt;
for(int i=0; i<s.length(); ++i) {
int ch = s[i]-'a';
if(nxt[now][ch]==0) nxt[now][ch]=newnode();
now = nxt[now][ch];
}
++cnt[now];
// printf("now: %d, c: %d\n", now, cnt[now]);
}
void build()
{
queue<int> q;
for(int i=0;i<26;++i)
{
if(nxt[rt][i]==0)
nxt[rt][i]=rt;
else
{
fail[nxt[rt][i]]=rt;
q.push(nxt[rt][i]);
}
}
while(!q.empty())
{
int now=q.front();q.pop();
cnt[now] += cnt[fail[now]];
for(int i=0;i<26;++i)
{
if(nxt[now][i]==0)
nxt[now][i]=nxt[fail[now]][i];
else
{
fail[nxt[now][i]]=nxt[fail[now]][i];
q.push(nxt[now][i]);
}
}
}
}
void solve(int *c) {
int now = rt;
for(int i=0; i<t.length(); ++i) {
int ch = t[i]-'a';
now = nxt[now][ch];
c[i] += cnt[now];
}
}
}trie;
void cal(int *c) {
// puts("cal");
trie.init();
for(string &s : ls) trie.insert(s);
trie.build();
// for(int i=1; i<=trie.tot; i++) printf("cnt[%d]=%d\n", i, trie.cnt[i]);
trie.solve(c);
// for(int i=0; i<t.length(); ++i) {
// printf("c[%d]=%d\n", i, c[i]);
// }
}
int main() {
ios::sync_with_stdio(false);
cin >> t;
int n;
cin >> n;
for(int i=0; i<n; ++i) {
string s;
cin >> s;
ls.push_back(s);
}
cal(c);
reverse(t.begin(), t.end());
for(string &s : ls) reverse(s.begin(), s.end());
cal(revc);
ll ans = 0;
for(int i=0; i+1<t.length(); ++i) {
ans += 1LL*c[i]*revc[t.length()-i-2];
// printf("%d %d\n", i, t.length()-i-2);
}
cout << ans<<endl;
}