【题目链接】
【思路要点】
- 分出的区间应当头、尾元素均为\(s_0\),否则可以使不是\(s_0\)的元素自成一段来使答案更优。
- 因此,我们将每个位置按照\(s_i\)分类,分别处理。
- 考虑\(i<j\)且\(s_i=s_j\),一旦在某个位置\(k\),决策点\(i\)优于决策点\(j\),那么决策点\(i\)就会始终优于决策点\(j\),这是由于转移方程中出现次数上具有平方,因此较靠前的决策点增长较快。
- 因此,我们对每种\(s_i\)维护一个单调栈来维护一个有序的决策集合,并记录集合中相邻的两个决策点最优性发生变化的时刻即可。
- 时间复杂度\(O(NLogN)\)。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 100005; const int MAXM = 10005; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } long long f[MAXN]; int n, m, s[MAXN], home[MAXN], top[MAXN]; vector <int> pos[MAXM], q[MAXM], ql[MAXM], qr[MAXM]; long long calc(int from, int to) { return f[from - 1] + 1ll * s[to] * (home[to] - home[from] + 1) * (home[to] - home[from] + 1); } int main() { freopen("BZOJ4709.in", "r", stdin); freopen("BZOJ4709.out", "w", stdout); read(n); m = 0; for (int i = 1; i <= n; i++) { read(s[i]), pos[s[i]].push_back(i); home[i] = pos[s[i]].size() - 1; chkmax(m, s[i]); } for (int i = 1; i <= m; i++) { q[i].resize(pos[i].size()); ql[i].resize(pos[i].size()); qr[i].resize(pos[i].size()); top[i] = -1; } q[s[1]][++top[s[1]]] = 1; ql[s[1]][top[s[1]]] = 0; qr[s[1]][top[s[1]]] = pos[s[1]].size() - 1; for (int i = 1; i <= n; i++) { int col = s[i]; f[i] = calc(q[col][top[col]], i); if (i == n) break; col = s[i + 1]; while (top[col] != -1 && home[i + 1] > qr[col][top[col]]) top[col]--; if (top[col] == -1 || calc(i + 1, i + 1) > calc(q[col][top[col]], i + 1)) { while (top[col] != -1 && calc(i + 1, pos[col][qr[col][top[col]]]) >= calc(q[col][top[col]], pos[col][qr[col][top[col]]])) top[col]--; if (top[col] == -1) { top[col]++; q[col][top[col]] = i + 1; ql[col][top[col]] = home[i + 1]; qr[col][top[col]] = pos[col].size() - 1; } else { int pl = home[i + 1], pr = qr[col][top[col]]; while (pl < pr) { int mid = (pl + pr + 1) / 2; if (calc(i + 1, pos[col][mid]) >= calc(q[col][top[col]], pos[col][mid])) pl = mid; else pr = mid - 1; } ql[col][top[col]] = pl + 1; top[col]++; q[col][top[col]] = i + 1; ql[col][top[col]] = home[i + 1]; qr[col][top[col]] = pl; } } } writeln(f[n]); return 0; }