题意:
给你一个s和t串,然后用s去匹配t串中的每一个长度与s相等的子串。如果 s i = t j s_i=t_j si=tj或者 p i d x ( s i ) = i d x ( t j ) p_{idx(s_i)}=idx(t_j) pidx(si)=idx(tj)则可以匹配,输出每一个位置可以匹配的情况。
做法:
显然是不能用常规的字符串匹配去解决。这里如果你做过洛谷上的带有通配符的字符串的匹配,可能知道怎么做,就是用FFT 解决。
带通配符的字符串匹配
这里我们构造一个式子,
(
s
i
−
t
j
)
2
(
p
j
−
t
j
)
2
(s_i-t_j)^2(p_j-t_j)^2
(si−tj)2(pj−tj)2表示i和j这两个位置匹配。
如果长度为m的匹配则
∑
j
=
0
m
−
1
(
s
j
−
t
i
+
j
)
2
(
p
j
−
t
i
+
j
)
2
\sum_{j=0}^{m-1}(s_{j}-t_{i+j})^2(p_{j}-t_{i+j})^2
∑j=0m−1(sj−ti+j)2(pj−ti+j)2有多少i的位置等于零,这个化简一下显然就是多项式相乘,用FFT就可以了。
我这里写得是NTT写得丑刚刚卡过去,如果WA126注意随机一下 idx,就可以了。
#include "bits/stdc++.h"
using namespace std;
#define VI vector<int>
#define ll long long
#define SZ(x) ((int)x.size())
#define all(x) x.begin(),x.end()
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x; }
const int maxn = 1 << 19;
const ll mod = 998244353;
int Mod(int x) {
if (x >= mod) x -= mod;
return x;
}
ll quick(ll a, ll n) {
ll ans = 1;
while (n) {
if (n & 1) ans = ans * a % mod;
n >>= 1;
a = a * a % mod;
}
return ans;
}
const ll g = 3;
int r[maxn], tot, lim, roots[33];
void ntt(int *a, int inv) {
for (int i = 0; i < tot; i++) {
if (i < r[i]) swap(a[i], a[r[i]]);
}
for (int l = 2, id = 1; l <= tot; l <<= 1, id++) {
int tmp = roots[id];
if (inv == -1) tmp = quick(tmp, mod - 2);
int m = l / 2;
for (int j = 0; j < tot; j += l) {
int w = 1;
for (int i = 0; i < m; i++) {
int t = 1ll * a[j + i + m] * w % mod;
a[j + i + m] = Mod(a[j + i] - t + mod);
a[j + i] = Mod(a[j + i] + t);
w = 1ll * w * tmp % mod;
}
}
}
if (inv == -1) {
int t = quick(tot, mod - 2);
for (int i = 0; i < tot; i++) {
a[i] = 1LL * a[i] * t % mod;
}
}
}
void init(int n, int m) {
tot = 1, lim = 0;
while (tot < n + m) tot <<= 1, lim++;
for (int i = 0; i < tot; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lim - 1));
}
for (int i = 1; i <= lim; i++) {
int t = 1 << i;
roots[i] = quick(g, (mod - 1) / t);
}
}
int A[maxn], B[maxn], P[maxn];
vector<int> multiply(int *a, int *b, int n, int m) {
for (int i = 0; i < m; i++) B[i] = b[i];
for (int i = 0; i < n; i++) A[i] = a[i];
ntt(A, 1);
ntt(B, 1);
for (int i = 0; i < tot; i++) P[i] = 1ll * A[i] * B[i] % mod;
ntt(P, -1);
vector<int> ans(tot, 0);
for (int i = 0; i < tot; i++) {
ans[i] = P[i];
P[i] = A[i] = B[i] = 0;
}
return ans;
}
int n, p[30];
char s[maxn], t[maxn];
int c[5][maxn], d[5][maxn], res[maxn], val[maxn];
int main() {
// freopen("1.in", "r", stdin);
// double ss = clock();
for (int i = 0; i < 26; i++) val[i] = rnd(100);
for (int i = 0, x; i < 26; i++) {
scanf("%d", &x);
p[x - 1] = i;
}
scanf("%s%s", s, t);
n = strlen(t);
for (int i = 0; i < n; i++) {
int id = t[i] - 'a';
int x = val[id], y = val[p[id]];
c[4][i] = 1ll;
c[3][i] = (mod - 2 * x - 2 * y);
c[2][i] = x * x + y * y + 4 * x * y;
c[1][i] = mod - 2 * x * y * (x + y);
c[0][i] = x * x * y * y;
}
int m = strlen(s);
for (int i = 0; i < m; i++) {
int id = val[s[i] - 'a'];
d[4][i] = id * id * id * id;
d[3][i] = id * id * id;
d[2][i] = id * id;
d[1][i] = id;
d[0][i] = 1ll;
}
init(n, m);
for (int i = 0; i < 5; i++) {
reverse(d[i], d[i] + m);
vector<int> pp = multiply(c[i], d[i], n, m);
for (int j = m - 1; j < n; j++) res[j] = (res[j] + pp[j]) % mod;
}
for (int i = m - 1; i < n; i++) {
if (res[i] == 0)printf("1");
else printf("0");
}
puts("");
// cout << clock() - ss << endl;
return 0;
}