【题目链接】
【思路要点】
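- 补档博客，原文无详细题解，此处仅根据代码概述思路。
- 目标：统计满足 $i<j<k$ 且 $a_j-a_i=a_k-a_j$（即 $a_i+a_k=2a_j$）的三元组个数。
- 将序列分成长度约为 $6.5\sqrt{n}$ 的块，按中间项 $a_j$ 所在的块分类计数。
- 若 $a_i$ 在该块之前、$a_k$ 在该块之后：对块前的值计数 $P$ 与块后的值计数 $S$ 做卷积，取下标 $2a_j$ 处的值。令 $A=(P+S)+\mathrm{i}(P-S)$，则 $\operatorname{Re}(A^2)=4PS$，只需一次正变换和一次逆变换。
- 若至少有两项在块内：枚举块内数对，第三项分别到块前计数、块后计数或两者之间的块内计数中查询。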
【代码】
// Counts index triples i < j < k with a[j] - a[i] == a[k] - a[j]
// (equivalently a[i] + a[k] == 2 * a[j]) in a sequence read from stdin.
//
// Input:  n, followed by n integers.
// Output: the number of such triples.
//
// Approach: split the sequence into blocks of about 6.5 * sqrt(n) elements and
// classify every triple by the block that contains its middle index j:
//   * i before the block and k after it: one FFT convolution of the prefix and
//     suffix value-count arrays, evaluated at 2 * a[j];
//   * at least two of i, j, k inside the block: enumerate the in-block pairs
//     and look up the third term in the prefix counts, the suffix counts, or
//     the counts of in-block elements strictly between the pair.
// Time is roughly O((n / blockLen) * R log R + n * blockLen), where R is the
// FFT length (about twice the value range); memory is O(n + R).
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

namespace {

/// Reads one integer from stdin, skipping any non-digit characters before it.
/// Each '-' seen while skipping toggles the sign.
/// @return false if EOF is reached before any digit (x is then 0).
template <typename T>
bool readInteger(T &x) {
    x = 0;
    int sign = 1;
    // Keep getchar()'s result as int: it makes EOF distinguishable and keeps
    // the std::isdigit argument in its defined domain.
    int c = std::getchar();
    for (; c != EOF && !std::isdigit(c); c = std::getchar())
        if (c == '-') sign = -sign;
    if (c == EOF) return false;
    for (; c != EOF && std::isdigit(c); c = std::getchar())
        x = x * 10 + (c - '0');
    x *= sign;
    return true;
}

/// Minimal complex number in extended precision. Kept hand-rolled rather than
/// std::complex, whose operator* may carry inf/NaN-handling overhead.
struct Point {
    long double r, i;
};

Point operator+(Point a, Point b) { return Point{a.r + b.r, a.i + b.i}; }
Point operator-(Point a, Point b) { return Point{a.r - b.r, a.i - b.i}; }
Point operator*(Point a, Point b) {
    return Point{a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r};
}

/// Iterative radix-2 FFT over a fixed power-of-two length.
/// The bit-reversal permutation and the roots of unity are computed once.
/// Each root is evaluated directly in long double, not by repeated
/// multiplication, so rounding error does not accumulate along a stage.
class FftPlan {
public:
    /// @param size transform length; must be a power of two (>= 1).
    explicit FftPlan(int size) : size_(size), rev_(size), roots_(size / 2) {
        for (int idx = 0; idx < size_; ++idx) {
            int reversed = 0;
            for (int bit = 1, rest = idx; bit < size_; bit <<= 1, rest >>= 1)
                reversed = (reversed << 1) | (rest & 1);
            rev_[idx] = reversed;
        }
        const long double pi = std::acos(-1.0L);
        for (int k = 0; k < size_ / 2; ++k) {
            const long double angle = 2 * pi * k / size_;
            roots_[k] = Point{std::cos(angle), std::sin(angle)};
        }
    }

    /// In-place transform of `a`, whose size must equal size().
    /// The forward transform uses roots e^{+2*pi*i*k/size}. The inverse uses
    /// their conjugates and is NOT normalized: the caller divides by size().
    void transform(std::vector<Point> &a, bool inverse) const {
        for (int idx = 0; idx < size_; ++idx)
            if (rev_[idx] > idx) std::swap(a[idx], a[rev_[idx]]);
        for (int len = 2; len <= size_; len <<= 1) {
            const int half = len / 2;
            const int stride = size_ / len;  // roots_[j * stride] = e^{2*pi*i*j/len}
            for (int start = 0; start < size_; start += len) {
                for (int j = 0; j < half; ++j) {
                    Point w = roots_[j * stride];
                    if (inverse) w.i = -w.i;
                    const Point u = a[start + j];
                    const Point t = a[start + j + half] * w;
                    a[start + j] = u + t;
                    a[start + j + half] = u - t;
                }
            }
        }
    }

    int size() const { return size_; }

private:
    int size_;
    std::vector<int> rev_;
    std::vector<Point> roots_;
};

/// Counts triples whose middle index lies in [lo, hi], whose left term lies
/// before lo (value counts in `pre`) and whose right term lies after hi
/// (value counts in `suf`). Uses a single complex FFT: with
/// A = (pre + suf) + i(pre - suf), Re(A * A) = 4 * (pre convolved with suf).
/// `work` is scratch space of length plan.size().
long long countAcrossBlock(const std::vector<int> &values, int lo, int hi,
                           const std::vector<int> &pre,
                           const std::vector<int> &suf, const FftPlan &plan,
                           std::vector<Point> &work) {
    const int fftSize = plan.size();
    for (int x = 0; x < fftSize; ++x)
        work[x] = Point{static_cast<long double>(pre[x] + suf[x]),
                        static_cast<long double>(pre[x] - suf[x])};
    plan.transform(work, false);
    for (Point &p : work) p = p * p;
    plan.transform(work, true);
    // Undo the inverse transform's missing 1/size and the factor 4 of the trick.
    const long double scale = 4.0L * fftSize;
    long long total = 0;
    for (int j = lo; j <= hi; ++j)
        total += std::llround(work[2 * values[j]].r / scale);
    return total;
}

/// Counts triples whose middle index lies in [lo, hi] and which have at least
/// two indices inside that block. `pre`/`suf` hold value counts strictly
/// before lo / after hi. `between` must be all zero on entry and is restored
/// to all zero on return.
long long countWithinBlock(const std::vector<int> &values, int lo, int hi,
                           const std::vector<int> &pre,
                           const std::vector<int> &suf,
                           std::vector<int> &between) {
    long long total = 0;
    for (int i = lo; i <= hi; ++i) {
        const int vi = values[i];
        for (int k = i + 1; k <= hi; ++k) {
            const int vk = values[k];
            // i is the middle and k the right term; the left term is before the block.
            if (2 * vi - vk >= 0) total += pre[2 * vi - vk];
            // k is the middle and i the left term; the right term is after the block.
            if (2 * vk - vi >= 0) total += suf[2 * vk - vi];
            // i and k are the outer terms; the middle lies strictly between them.
            if ((vi + vk) % 2 == 0) total += between[(vi + vk) / 2];
            ++between[vk];
        }
        for (int k = i + 1; k <= hi; ++k) --between[values[k]];
    }
    return total;
}

/// Counts index triples i < j < k with values[i] + values[k] == 2 * values[j].
/// Values are shifted so the minimum becomes 0, which leaves the answer
/// unchanged. Precondition: max - min must be small enough that arrays of
/// roughly 2 * (max - min) elements can be allocated without int overflow.
long long countArithmeticTriples(std::vector<int> values) {
    const int n = static_cast<int>(values.size());
    if (n < 3) return 0;

    const int minValue = *std::min_element(values.begin(), values.end());
    for (int &v : values) v -= minValue;
    const int maxValue = *std::max_element(values.begin(), values.end());

    // Smallest power of two strictly greater than 2 * maxValue. This keeps
    // index 2 * value in range and stops the cyclic convolution from wrapping.
    int fftSize = 1;
    while (fftSize <= 2 * maxValue) fftSize <<= 1;

    std::vector<int> pre(fftSize, 0);
    std::vector<int> suf(fftSize, 0);
    std::vector<int> between(fftSize, 0);
    for (int v : values) ++suf[v];

    const int blockLen = std::max(
        1, std::min(n, static_cast<int>(6.5 * std::sqrt(static_cast<double>(n)))));
    const FftPlan plan(fftSize);
    std::vector<Point> work(fftSize);

    long long total = 0;
    for (int lo = 0; lo < n; lo += blockLen) {
        const int hi = std::min(n - 1, lo + blockLen - 1);
        for (int idx = lo; idx <= hi; ++idx) --suf[values[idx]];
        // The cross-block term is identically zero when either side is empty.
        if (lo > 0 && hi < n - 1)
            total += countAcrossBlock(values, lo, hi, pre, suf, plan, work);
        total += countWithinBlock(values, lo, hi, pre, suf, between);
        for (int idx = lo; idx <= hi; ++idx) ++pre[values[idx]];
    }
    return total;
}

}  // namespace

int main() {
    int n = 0;
    if (!readInteger(n) || n < 0) n = 0;
    std::vector<int> values;
    for (int idx = 0; idx < n; ++idx) {
        int x = 0;
        if (!readInteger(x)) break;  // truncated input: use what was read
        values.push_back(x);
    }
    std::cout << countArithmeticTriples(std::move(values)) << '\n';
    return 0;
}