根据贪心,不难想到每次会把最长队伍末尾的那辆车移动到最短队伍的末尾。但由于 kkk 的存在,会导致一些冗余移动的存在。设需要挪动 CCC 辆车,则怒气值可以表示为 f(C)+kCf(C) + kCf(C)+kC,其中 f(C)f(C)f(C) 是排队所产生的怒气值,kCkCkC 为变道产生的额外怒气值。仔细分析以后,可以发现这是一个凸函数,因此考虑三分答案。
一开始想要三分需要挪车的最短长度 yyy,但是不能忽略 kkk 的影响,有些队伍的长度虽然 >y> y>y,但挪动不移动会更优。于是三分挪动车辆的数量才是最优的。
具体来说,可以枚举哪些队伍的车辆会减少/增加。若现在考虑会减少的队伍的车辆,给 aia_iai 排序后,设当前最长队伍的车辆数为 xxx,次长的为 yyy (x≠yx \neq yx=y),然后长度为 x,yx,yx,y 的队伍的数量分别为 fx,fyf_x,f_yfx,fy。若共需要移动 CCC 辆车,则有两种情况:
-
C≥(x−y)×fxC \ge (x - y) \times f_xC≥(x−y)×fx,也就是说长度为 xxx 的车可以直接变为 yyy,C←C−(x−y)×fx; fy←fx+fy; fx←0C \leftarrow C - (x - y) \times f_x;\ f_y \leftarrow f_x + f_y;\ f_x \leftarrow 0C←C−(x−y)×fx; fy←fx+fy; fx←0。
-
C<(x−y)×fxC < (x - y) \times f_xC<(x−y)×fx,此时会产生新的队伍长度,也就是 C←0; fx−⌊Cfx⌋−1←fx−⌊Cfx⌋−1+C mod fx; ←fx−⌊Cfx⌋+(fx−C mod fx)C \leftarrow 0;\ f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} + C \bmod f_x;\ \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor} + (f_x - C \bmod f_x)C←0; fx−⌊fxC⌋−1←fx−⌊fxC⌋−1+Cmodfx; ←fx−⌊fxC⌋+(fx−Cmodfx)。
可以发现最后队伍长度的种类数不会超过 n+2n + 2n+2,因此这是 O(n)O(n)O(n) 的。考虑增加的队伍的车辆同理,用 STL 来写会简单一点。但是由于多了一支 log\loglog,实测会超时:
ll tot = sum * k,res = sum,number = sum;
set <int> s;map <int,int> bg,sm;
s.insert (-1e9);
for (int i = 1;i <= n;++i) s.insert (a[i]),++bg[a[i]];
while (sum)
{
int x = *(--s.end ()),num = bg[x];s.erase (x);
int y = *(--s.end ());
if (sum >= 1ll * (x - y) * num)
{
sum -= 1ll * (x - y) * num;
bg[y] += num;bg[x] = 0;
}
else
{
bg[x] = 0;
int tmp = sum % num;
if (tmp) bg[x - sum / num - 1] += tmp;
bg[x - sum / num] += num - tmp;
sum = 0;
}
}
s.clear ();
for (auto [x,num] : bg)
if (num) s.insert (x),sm[x] = num;
s.insert (1e9);
while (res)
{
int x = *s.begin (),num = sm[x];s.erase (x);
int y = *s.begin ();
if (res >= 1ll * (y - x) * num)
{
res -= 1ll * (y - x) * num;
sm[y] += num;sm[x] = 0;
}
else
{
sm[x] = 0;
int tmp = res % num;
if (tmp) sm[x + res / num + 1] += tmp;
sm[x + res / num] += num - tmp;
res = 0;
}
}
for (auto [x,num] : sm) tot += 1ll * x * (x + 1) / 2 * num;
return tot;
};
再次思考可以发现 STL 的 log\loglog 完全是多余的,可以通过数组来替代,但需要小心清空与去重的问题。最后的 AC 代码如下,时间复杂度 O(nlogn)O(n \log n)O(nlogn):
#include <bits/stdc++.h>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 2e18
#define pii pair <int,int>
using namespace std;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int a[MAX],b[MAX];
vector <int> bg (1000001,0),sm (1000001,0);
void solve ()
{
int n = read (),k = read ();ll ans = INF;
for (int i = 1;i <= n;++i) a[i] = read ();
sort (a + 1,a + 1 + n);
auto check = [&] (ll sum) -> ll
{
ll tot = sum * k,res = sum;int cnt = 0;
vector <int> p;
for (int i = 1;i <= n;++i) p.push_back (a[i]);
for (int i = 1;i <= n;++i)
{
if (!bg[a[i]]) b[++cnt] = a[i];
++bg[a[i]];
}
b[0] = -1e9;
while (sum > 0)
{
int x = b[cnt--],num = bg[x];
int y = b[cnt];
if (sum >= 1ll * (x - y) * num)
{
sum -= 1ll * (x - y) * num;
bg[y] += num;bg[x] = 0;
}
else
{
bg[x] = 0;
int tmp = sum % num;
bg[x - sum / num] += num - tmp,p.push_back (x - sum / num);
if (tmp) bg[x - sum / num - 1] += tmp,p.push_back (x - sum / num - 1);
sum = 0;
}
}
cnt = 0;
for (auto v : p)
if (bg[v]) b[++cnt] = v,sm[v] = bg[v],bg[v] = 0;
p.clear ();
for (int i = 1;i <= cnt;++i) p.push_back (b[i]);
b[++cnt] = 1e9;cnt = 1;
while (res > 0)
{
int x = b[cnt++],num = sm[x];
int y = b[cnt];
if (res >= 1ll * (y - x) * num)
{
res -= 1ll * (y - x) * num;
sm[y] += num;sm[x] = 0;
}
else
{
sm[x] = 0;
int tmp = res % num;
if (tmp) sm[x + res / num + 1] += tmp,p.push_back (x + res / num + 1);
sm[x + res / num] += num - tmp,p.push_back (x + res / num);
res = 0;
}
}
for (auto v : p) tot += 1ll * v * (v + 1) / 2 * sm[v],sm[v] = 0;
return tot;
};
ll l = 0,r = accumulate (a + 1,a + n + 1,0ll);
while (l < r)
{
ll midl = l + (r - l) / 3,midr = r - (r - l) / 3;
ll v1 = check (midl),v2 = check (midr);
ans = min (ans,min (v1,v2));
if (v1 <= v2) r = midr - 1;
else l = midl + 1;
}
printf ("%lld\n",ans);
}
int main ()
{
int t = read ();
while (t--) solve ();
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}