【题目链接】
【思路要点】
- 先写一发bitset交上去试试,时间复杂度\(O(\frac{NQ}{w})\),其中\(w=64\),结果超时(TLE)了,大概要12s才能跑出满数据。
- 正解是分块FFT,令分块大小为\(Size\),进行\(O(\frac{N}{Size})\)次FFT,处理出\(O(\frac{N}{Size})\)个\(T\)的后缀与\(S\)的每个长度为\(|T|\)的子串能够匹配的位数,询问时调用整块的结果,块内暴力计算。
- 这样的时间复杂度是\(O(\frac{N^2\log N}{Size}+Q\cdot Size)\),取\(Size=\frac{N}{Q}\sqrt{Q\log N}\)时,可以获得渐近意义下最优复杂度\(O(N\sqrt{Q\log N})\),笔者的程序取\(Size=8000\),运行时间为6.2s。
- 为什么这么慢呢?因为笔者每次会对0和1各做一次FFT,也就是说,FFT部分的常数被乘上了2。
- 但实际上有一种更加高妙的做法来解决01匹配问题:我们令0为1,1为-1,然后FFT。那么两个字符如果匹配,乘积为1,否则为-1,于是卷积结果等于“匹配位数减去不匹配位数”。设对齐后重叠部分的长度为\(L\),卷积结果为\(c\),则匹配位数为\(\frac{c+L}{2}\),这样只做一次FFT再对结果数组进行简单处理,就可以得到上面两次FFT一样的效果。
- 笔者此时的程序取\(Size=7200\),运行时间为4.5s。
【代码】
/*Clever Single FFT 4.5s approximately*/ #include<bits/stdc++.h> using namespace std; const int MAXN = 524288; const int REAL = 200005; const int SIZE = 7200; const int MAXK = 205; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } namespace FFT { const double pi = acos(-1); struct point {double x, y; }; point operator + (point a, point b) {return (point) {a.x + b.x, a.y + b.y}; } point operator - (point a, point b) {return (point) {a.x - b.x, a.y - b.y}; } point operator * (point a, point b) {return (point) {a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x}; } point operator / (point a, double x) {return (point) {a.x / x, a.y / x}; } int N, Log, home[MAXN]; point tmp[MAXN]; void FFTinit() { for (int i = 0; i < N; i++) { int tmp = i, ans = 0; for (int j = 1; j <= Log; j++) { ans <<= 1; ans += tmp & 1; tmp >>= 1; } home[i] = ans; } } void FFT(point *a, int mode) { for (int i = 0; i < N; i++) if (home[i] < i) swap(a[i], a[home[i]]); for (int len = 2; len <= N; len <<= 1) { point delta = (point) {cos(2 * pi / len * mode), sin(2 * pi / len * mode)}; for (int i = 0; i < N; i += len) { point now = (point) {1, 0}; for (int j = i, k = i + len / 2; k < i + len; j++, k++) { point tmp = a[j]; point tnp = a[k] * now; a[j] = tmp + tnp; a[k] = tmp - tnp; now = now * delta; } } } if (mode == -1) { for (int i = 0; i < N; i++) a[i] = a[i] / (4 * N); } } void times(double *a, double *b, double *c, int limit) { N = 1, Log = 0; while (N <= 2 * limit) { N <<= 1; Log++; } for (int i = 0; i < limit; i++) tmp[i] = (point) {a[i] + b[i], 
a[i] - b[i]}; for (int i = limit; i < N; i++) tmp[i] = (point) {0, 0}; FFTinit(); FFT(tmp, 1); for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i]; FFT(tmp, -1); for (int i = 0; i < N; i++) c[i] = tmp[i].x; } void times(int *a, int *b, int *c, int limit) { N = 1, Log = 0; while (N <= 2 * limit) { N <<= 1; Log++; } for (int i = 0; i < limit; i++) tmp[i] = (point) {(double) (a[i] + b[i]), (double) (a[i] - b[i])}; for (int i = limit; i < N; i++) tmp[i] = (point) {0, 0}; FFTinit(); FFT(tmp, 1); for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i]; FFT(tmp, -1); for (int i = 0; i < N; i++) c[i] = (int) (tmp[i].x + 0.5); } void times(long long *a, long long *b, long long *c, int limit) { N = 1, Log = 0; while (N <= 2 * limit) { N <<= 1; Log++; } for (int i = 0; i < limit; i++) tmp[i] = (point) {(double) (a[i] + b[i]), (double) (a[i] - b[i])}; for (int i = limit; i < N; i++) tmp[i] = (point) {0, 0}; FFTinit(); FFT(tmp, 1); for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i]; FFT(tmp, -1); for (int i = 0; i < N; i++) c[i] = (long long) (tmp[i].x + 0.5); } } int ls, lt, tot; char s[MAXN], t[MAXN]; int ans[MAXK][REAL]; int a[MAXN], b[MAXN], c[MAXN]; int index[MAXN], l[MAXN], r[MAXN]; int query(int ps, int pt) { int tans = 0; if (ps + SIZE >= ls || pt + SIZE >= lt) { while (ps < ls && pt < lt) tans += s[ps++] == t[pt++]; return tans; } while (index[pt] == index[pt - 1]) tans += s[ps++] == t[pt++]; tans += ans[index[pt]][ps]; return tans; } int main() { scanf("\n%s\n%s", s, t); ls = strlen(s); lt = strlen(t); for (int i = 0; i < lt; i++) { if (i % SIZE == 0) l[++tot] = i; index[i] = tot; r[tot] = i; } for (int p = 1; p <= tot; p++) { memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b)); for (int i = 0; i < ls; i++) if (s[i] == '0') a[i] = 1; else a[i] = -1; for (int i = l[p]; i < lt; i++) if (t[i] == '0') b[lt - 1 - i] = 1; else b[lt - 1 - i] = -1; int len = lt - l[p] - 1; int limit = max(ls, lt - l[p]); FFT::times(a, b, c, limit); for (int i = 0; i < ls; i++) ans[p][i] += (c[i 
+ len] + len + 1 - max(0, i + len - ls + 1)) / 2; } int q; read(q); while (q--) { int ps, pt, len; read(ps), read(pt), read(len); writeln(len - query(ps, pt) + query(ps + len, pt + len)); } return 0; } /*Simple FFT Twice 6.2s approximately*/ #include<bits/stdc++.h> using namespace std; const int MAXN = 524288; const int REAL = 200005; const int SIZE = 8000; const int MAXK = 205; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } namespace FFT { const double pi = acos(-1); struct point {double x, y; }; point operator + (point a, point b) {return (point) {a.x + b.x, a.y + b.y}; } point operator - (point a, point b) {return (point) {a.x - b.x, a.y - b.y}; } point operator * (point a, point b) {return (point) {a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x}; } point operator / (point a, double x) {return (point) {a.x / x, a.y / x}; } int N, Log, home[MAXN]; point tmp[MAXN]; void FFTinit() { for (int i = 0; i < N; i++) { int tmp = i, ans = 0; for (int j = 1; j <= Log; j++) { ans <<= 1; ans += tmp & 1; tmp >>= 1; } home[i] = ans; } } void FFT(point *a, int mode) { for (int i = 0; i < N; i++) if (home[i] < i) swap(a[i], a[home[i]]); for (int len = 2; len <= N; len <<= 1) { point delta = (point) {cos(2 * pi / len * mode), sin(2 * pi / len * mode)}; for (int i = 0; i < N; i += len) { point now = (point) {1, 0}; for (int j = i, k = i + len / 2; k < i + len; j++, k++) { point tmp = a[j]; point tnp = a[k] * now; a[j] = tmp + tnp; a[k] = tmp - tnp; now = now * delta; } } } if (mode == -1) { for (int i = 0; i < N; i++) a[i] = 
a[i] / (4 * N); } } void times(double *a, double *b, double *c, int limit) { N = 1, Log = 0; while (N <= 2 * limit) { N <<= 1; Log++; } for (int i = 0; i < limit; i++) tmp[i] = (point) {a[i] + b[i], a[i] - b[i]}; for (int i = limit; i < N; i++) tmp[i] = (point) {0, 0}; FFTinit(); FFT(tmp, 1); for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i]; FFT(tmp, -1); for (int i = 0; i < N; i++) c[i] = tmp[i].x; } void times(int *a, int *b, int *c, int limit) { N = 1, Log = 0; while (N <= 2 * limit) { N <<= 1; Log++; } for (int i = 0; i < limit; i++) tmp[i] = (point) {(double) (a[i] + b[i]), (double) (a[i] - b[i])}; for (int i = limit; i < N; i++) tmp[i] = (point) {0, 0}; FFTinit(); FFT(tmp, 1); for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i]; FFT(tmp, -1); for (int i = 0; i < N; i++) c[i] = (int) (tmp[i].x + 0.5); } void times(long long *a, long long *b, long long *c, int limit) { N = 1, Log = 0; while (N <= 2 * limit) { N <<= 1; Log++; } for (int i = 0; i < limit; i++) tmp[i] = (point) {(double) (a[i] + b[i]), (double) (a[i] - b[i])}; for (int i = limit; i < N; i++) tmp[i] = (point) {0, 0}; FFTinit(); FFT(tmp, 1); for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i]; FFT(tmp, -1); for (int i = 0; i < N; i++) c[i] = (long long) (tmp[i].x + 0.5); } } int ls, lt, tot; char s[MAXN], t[MAXN]; int ans[MAXK][REAL]; int a[MAXN], b[MAXN], c[MAXN]; int index[MAXN], l[MAXN], r[MAXN]; int query(int ps, int pt) { int tans = 0; if (ps + SIZE >= ls || pt + SIZE >= lt) { while (ps < ls && pt < lt) tans += s[ps++] == t[pt++]; return tans; } while (index[pt] == index[pt - 1]) tans += s[ps++] == t[pt++]; tans += ans[index[pt]][ps]; return tans; } int main() { scanf("\n%s\n%s", s, t); ls = strlen(s); lt = strlen(t); for (int i = 0; i < lt; i++) { if (i % SIZE == 0) l[++tot] = i; index[i] = tot; r[tot] = i; } for (int p = 1; p <= tot; p++) { memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b)); for (int i = 0; i < ls; i++) a[i] = s[i] == '0'; for (int i = l[p]; i < lt; i++) b[lt - 1 - i] = 
t[i] == '0'; int len = lt - l[p] - 1; int limit = max(ls, lt - l[p]); FFT::times(a, b, c, limit); for (int i = 0; i < ls; i++) ans[p][i] += c[i + len]; memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b)); for (int i = 0; i < ls; i++) a[i] = s[i] == '1'; for (int i = l[p]; i < lt; i++) b[lt - 1 - i] = t[i] == '1'; len = lt - l[p] - 1; limit = max(ls, lt - l[p]); FFT::times(a, b, c, limit); for (int i = 0; i < ls; i++) ans[p][i] += c[i + len]; } int q; read(q); while (q--) { int ps, pt, len; read(ps), read(pt), read(len); writeln(len - query(ps, pt) + query(ps + len, pt + len)); } return 0; } /*Bitset Version 12s approximately, TLE*/ #include<bits/stdc++.h> using namespace std; const int MAXN = 200005; const int bit = 64; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int ls, lt; char s[MAXN], t[MAXN]; bitset <bit> masks[MAXN], maskt[MAXN]; int main() { scanf("\n%s\n%s", s, t); ls = strlen(s); lt = strlen(t); for (int i = 0; i <= ls - bit; i++) for (int j = 0; j < bit; j++) if (s[i + j] == '1') masks[i].set(j); else masks[i].reset(j); for (int i = 0; i <= lt - bit; i++) for (int j = 0; j < bit; j++) if (t[i + j] == '1') maskt[i].set(j); else maskt[i].reset(j); int q; read(q); while (q--) { register int ps, pt, len, ans = 0; read(ps), read(pt), read(len); while (len >= bit) { ans += (masks[ps] ^ maskt[pt]).count(); ps += bit, pt += bit, len -= bit; } for (int i = 0; i < len; i++) ans += s[ps + i] != t[pt + i]; writeln(ans); } return 0; }