【思路要点】
- 我们先来考虑这个问题在序列上的形式。
- 我们要将序列分成 $k$ 段,使得各段内所有数到该段中位数的距离之和(对所有段求和)最小。
- 由于代价函数 $w$ 满足四边形不等式 $w(i,k)+w(j,l)\le w(i,l)+w(j,k)\ (i\le j\le k\le l)$,因此该 DP 的决策点满足决策单调性。
- 那么利用决策单调性进行分治,这个问题在序列上的形式就有一种简单的 $O(NK\log N)$ 的做法。
- 另外,通过打表可以注意到:令 $opt(x)$ 为 $k=x$ 时的最优解,那么 $opt(x)$ 是一个关于 $x$ 的下凸函数。因此,通过二分斜率凸优化(WQS 二分),我们可以在 $O(N\log N\log V)$ 的时间内解决这个问题在序列上的形式。
- 这里涉及到了一个凸优化输出方案的小技巧:设当前二分出的斜率 $x$ 下,最优解至少分成 $k_1$ 段、至多分成 $k_2$ 段,且 $k_1<k<k_2$。若存在一组位置 $i\le j\le k\le l$(这里的 $i,j,k,l$ 表示位置,与段数 $k$ 无关),其中 $i,l$ 为最少分段中相邻的两个断点,$j,k$ 为最多分段中相邻的两个断点,那么由 $w(i,k)+w(j,l)\le w(i,l)+w(j,k)$,并且原有的两个分段均为最优解,我们取最少分段中 $i$ 以及其之前的断点、取最多分段中 $k$ 以及其之后的断点,形成的解一定也是一个最优解。可以证明,我们一定可以通过这种方式找到一种调整,将分的段数调整至 $k$。
- 考虑环上的问题。假设全局最优解的断点为 $(p_0,p_1,\dots,p_k)$(其中 $p_k=p_0+n$),那么显然 $(p_1,p_2,\dots,p_k,p_1+n)$ 也是全局最优解。由决策单调性,对于任意 $p_0\le q_0\le p_1$,以 $q_0$ 开头的最优解中一定存在一个 $(q_0,q_1,\dots,q_k)$,满足 $p_0\le q_0\le p_1,\ p_1\le q_1\le p_2,\ \dots,\ p_{k-1}\le q_{k-1}\le p_k$(并且 $q_k=q_0+n$)。
- 不妨令 $p_0\le 0\le p_1$,用上述凸优化的方式求出以 $0$ 开头的最优解 $(0,q_1,q_2,\dots,q_k)$。那么区间 $[0,q_1],[q_1,q_2],\dots,[q_{k-1},q_k]$ 中的每一段都会包含一个全局最优解上的断点。选择其中最短的一段,其长度必定是 $O(\frac{N}{k})$ 的。不妨设选择了 $[0,q_1]$,我们需要求出以其中每一个点开头的最优解,这些最优解中一定包含了全局最优解。
- 注意到我们已经确定了每一个决策点的范围,用最开始提到的分治做法求以一个点开头的最优解,复杂度是 $O(N\log N)$ 的。
- 假设我们当前需要求 $[l,r]$ 中每一个点开头的最优解,我们可以先求出以 $mid=\lfloor\frac{l+r}{2}\rfloor$ 开头的最优解,并由此进一步缩小 $[l,mid-1]$ 和 $[mid+1,r]$ 中各决策点的范围。如此递归处理,时间复杂度为 $O(N\log^2 N)$。
- 时间复杂度 $O(N\log^2 N+N\log N\log V)$。
【代码】
// Splits n points on a circle of length l into k contiguous arcs, each served by a
// centre at the arc's median point, minimising the total point-to-centre distance.
// Prints the minimum cost, then the k centre coordinates (mod l) in ascending order.
//
// Input: n, k, l, then the n coordinates a[1..n]. NOTE(review): weight() needs the
// unrolled sequence a[1..2n] to be non-decreasing, i.e. sorted input with values
// spanning less than l — assumed, not checked here.
//
// Every directive sits on its own line: a preprocessing directive ends at the
// newline, so code sharing a line with #include is silently discarded.
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <map>
#include <vector>
using namespace std;
const int MAXN = 4e5 + 5;   // arrays are indexed up to 2n (the unrolled circle)
const long long INF = 1e18;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
template <typename T> void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> void chkmin(T &x, T y) { x = min(x, y); }

// Reads one (optionally negative) integer from stdin.
template <typename T> void read(T &x) {
	x = 0;
	int f = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar())
		if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}

// Writes one integer to stdout.
template <typename T> void write(T x) {
	if (x < 0) x = -x, putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}

template <typename T> void writeln(T x) {
	write(x);
	puts("");
}

// DP state: (penalised) cost and number of segments used to reach it.
struct info {
	ll val;
	int cnt;
};

// Plain lexicographic order on (val, cnt).
bool operator < (info a, info b) {
	if (a.val == b.val) return a.cnt < b.cnt;
	else return a.val < b.val;
}

// "a is at least as good as b" for the DP that, among equal values, keeps the MOST
// segments. Named predicates (not comparison operators) on purpose: together with
// noWorseForLeast these are tie-breaking rules, not a consistent ordering.
bool noWorseForMost(info a, info b) {
	if (a.val != b.val) return a.val < b.val;
	return a.cnt >= b.cnt;
}

// "a is at least as good as b" for the DP that, among equal values, keeps the FEWEST
// segments.
bool noWorseForLeast(info a, info b) {
	if (a.val != b.val) return a.val < b.val;
	return a.cnt <= b.cnt;
}

// Extends a state by one more segment costing `val`.
info operator + (info a, ll val) {
	a.val += val;
	a.cnt += 1;
	return a;
}

// A run of target positions [l, r] whose current best decision point is `pos`.
struct range {
	int pos;
	int l, r;
};

int n, k, pos[MAXN], home;
ll ans, anspos[MAXN];
ll a[MAXN], sum[MAXN], l;

// Cost of the segment a[l+1..r]: sum of distances to its median a[mid].
// Uses the prefix sums in sum[] and assumes a[] is non-decreasing.
ll weight(int l, int r) {
	int mid = (l + r + 1) / 2;
	return (mid - l) * a[mid] - (sum[mid] - sum[l]) + (sum[r] - sum[mid]) - (r - mid) * a[mid];
}

info least[MAXN], most[MAXN];
int pathl[MAXN], pathm[MAXN];

// Offers decision point i to the monotone range queue q[lo..hi] of DP table f.
// better(x, y) means "x is at least as good as y". Back ranges that i already beats
// at their left end are popped (the previous range absorbs their right end); if i
// then beats the last range at its right end, the first position where i wins is
// binary-searched and that suffix becomes a new range owned by i. `last` is the
// right end of the whole DP interval.
template <typename Better>
void pushCandidate(range q[], int lo, int &hi, const info f[], int i, int last, Better better) {
	while (lo <= hi && better(f[i] + weight(i, q[hi].l), f[q[hi].pos] + weight(q[hi].pos, q[hi].l))) {
		hi--;
		if (hi >= lo) q[hi].r = q[hi + 1].r;
	}
	if (hi < lo) {
		q[++hi] = range{i, i + 1, last};
	} else if (better(f[i] + weight(i, q[hi].r), f[q[hi].pos] + weight(q[hi].pos, q[hi].r))) {
		int lb = q[hi].l, rb = q[hi].r;
		while (lb < rb) {
			int mid = (lb + rb) / 2;
			if (better(f[i] + weight(i, mid), f[q[hi].pos] + weight(q[hi].pos, mid))) rb = mid;
			else lb = mid + 1;
		}
		q[hi].r = lb - 1;
		q[++hi] = range{i, lb, last};
	}
}

// One step of the slope (penalty) binary search: DP over (from, from + n] where every
// segment costs weight(...) + cost. Computes two tie-breaks of the same optimum in one
// pass — `least` uses as few segments as possible, `most` as many — and records the
// chosen predecessors in pathl / pathm. Decision monotonicity is exploited with a
// queue of ranges per table.
void work(int from, ll cost) {
	least[from] = most[from] = info{0, 0};
	static range qleast[MAXN], qmost[MAXN];
	int lleast = 1, rleast = 1, lmost = 1, rmost = 1;
	qleast[1] = qmost[1] = range{from, from + 1, from + n};
	for (int i = from + 1; i <= from + n; i++) {
		// The front range of each queue holds the optimal decision point for i.
		pathl[i] = qleast[lleast].pos, pathm[i] = qmost[lmost].pos;
		most[i] = most[qmost[lmost].pos] + (weight(qmost[lmost].pos, i) + cost);
		least[i] = least[qleast[lleast].pos] + (weight(qleast[lleast].pos, i) + cost);
		assert(most[i].val == least[i].val);
		// Position i is settled: remove it from the front ranges.
		if (i == qmost[lmost].r) lmost++;
		else qmost[lmost].l++;
		if (i == qleast[lleast].r) lleast++;
		else qleast[lleast].l++;
		if (i == from + n) break;
		pushCandidate(qmost, lmost, rmost, most, i, from + n, noWorseForMost);
		pushCandidate(qleast, lleast, rleast, least, i, from + n, noWorseForLeast);
	}
}

// Minimum cost of splitting (from, from + n] into exactly k segments, via binary search
// on the per-segment penalty. On success writes the breakpoints into the global
// pos[0..k] (pos[0] = from, pos[k] = from + n) and returns the unpenalised cost;
// returns -1 if no penalty in [0, INF] admits k segments.
ll calc(int from) {
	ll l = 0, r = INF;
	while (l <= r) {
		ll mid = (l + r) / 2;
		work(from, mid);
		if (least[from + n].cnt <= k && most[from + n].cnt >= k) {
			static int posl[MAXN], posm[MAXN];
			// Recover both optimal partitions as breakpoint lists.
			for (int i = least[from + n].cnt, cur = from + n; i >= 0; i--) posl[i] = cur, cur = pathl[cur];
			for (int i = most[from + n].cnt, cur = from + n; i >= 0; i--) posm[i] = cur, cur = pathm[cur];
			// Splice least's breakpoints 0..i with most's tmp+1..most.cnt, which gives
			// i + (most.cnt - tmp) = k segments. By the quadrangle inequality the result
			// is still optimal when segment (posm[tmp], posm[tmp+1]) nests inside
			// (posl[i], posl[i+1]).
			for (int i = 0; i < least[from + n].cnt; i++) {
				int tmp = most[from + n].cnt - k + i;
				pos[i] = posl[i];
				if (posl[i] <= posm[tmp] && posl[i + 1] >= posm[tmp + 1]) {
					int now = i;
					for (int j = tmp + 1; j <= most[from + n].cnt; j++) pos[++now] = posm[j];
					return least[from + n].val - k * mid;
				}
			}
			assert(false);
		}
		// Too many segments -> raise the penalty; too few -> lower it.
		if (least[from + n].cnt > k) l = mid + 1;
		else r = mid - 1;
	}
	return -1;
}

map <int, ll> dp[MAXN];
map <int, int> path[MAXN];
vector <int> rangel, ranger, dppos;

// Divide-and-conquer optimisation for one layer of the exact k-segment DP: fills
// dp[x][layer] for x in [l, r], given that x's optimal previous breakpoint lies in
// [ql, qr] and moves monotonically with x.
void conquer(int layer, int l, int r, int ql, int qr) {
	if (l > r) return;
	int mid = (l + r) / 2;
	// References into std::map stay valid across insertions, so cache the slot.
	ll &best = dp[mid][layer];
	best = INF;
	for (int i = ql; i <= qr; i++) {
		ll tmp = dp[i][layer - 1] + weight(i, mid);
		if (tmp < best) best = tmp, path[mid][layer] = i;
	}
	assert(best != INF);
	conquer(layer, l, mid - 1, ql, path[mid][layer]);
	conquer(layer, mid + 1, r, path[mid][layer], qr);
}

// Exact DP: minimum cost of k segments from `from` to from + n, where the i-th
// breakpoint is restricted to [rangel[i], ranger[i]]. Stores the optimal breakpoints
// in dppos[0..k].
ll getdp(int from) {
	dp[from][0] = 0;
	path[from][0] = from;
	int lastl = from, lastr = from;
	for (int i = 1; i <= k; i++) {
		conquer(i, rangel[i], ranger[i], lastl, lastr);
		lastl = rangel[i], lastr = ranger[i];
	}
	dppos.resize(k + 1);
	for (int i = k, cur = from + n; i >= 0; i--) dppos[i] = cur, cur = path[cur][i];
	return dp[from + n][k];
}

// Evaluates every start point in [l, r]. The middle start is solved exactly; its
// breakpoints then serve as upper bounds for starts to its left and lower bounds for
// starts to its right. Keeps the best cost in ans and its start point in home.
void divide(int l, int r) {
	if (l > r) return;
	vector <int> bakl = rangel, bakr = ranger;
	int mid = (l + r) / 2;
	ll tmp = getdp(mid);
	vector <int> bak = dppos;
	if (tmp < ans) ans = tmp, home = mid;
	rangel = bakl, ranger = bak;
	divide(l, mid - 1);
	rangel = bak, ranger = bakr;
	divide(mid + 1, r);
	rangel = bakl, ranger = bakr;
}

int main() {
	read(n), read(k), read(l);
	// Unroll the circle: a[n+1..2n] are the same points one lap (l) later.
	for (int i = 1; i <= n; i++) read(a[i]), a[i + n] = a[i] + l;
	for (int i = 1; i <= 2 * n; i++) sum[i] = sum[i - 1] + a[i];
	// An optimal k-split starting at 0; each of its segments contains a breakpoint of
	// some global optimum.
	ans = calc(0), home = 0;
	// Choose the shortest segment and try every start point inside it.
	int Min = 1;
	for (int i = 1; i <= k; i++)
		if (pos[i] - pos[i - 1] < pos[Min] - pos[Min - 1]) Min = i;
	// Range for breakpoint j (j = 0..k): the segments of that solution read cyclically
	// from Min, shifted by n once they wrap around.
	for (int i = Min; i <= k; i++) {
		rangel.push_back(pos[i - 1]);
		ranger.push_back(pos[i]);
	}
	for (int i = 1; i <= Min; i++) {
		rangel.push_back(pos[i - 1] + n);
		ranger.push_back(pos[i] + n);
	}
	divide(rangel[0], ranger[0]);
	// Rebuild the breakpoints for the best start, then print the cost and the medians.
	writeln(ans = calc(home));
	for (int j = 1; j <= k; j++) anspos[j] = a[(pos[j] + pos[j - 1] + 1) / 2] % l;
	sort(anspos + 1, anspos + k + 1);
	for (int i = 1; i <= k; i++) printf("%lld ", anspos[i]);
	return 0;
}