题目大意
有一个长为nnn的序列aia_iai,你可以选择一个iii花费cxc_xcx元(x∈[1,m])(x\in[1,m])(x∈[1,m])将aia_iai变为⌊aix⌋\lfloor\dfrac{a_i}{x}\rfloor⌊xai⌋,你总共有KKK元,求最终序列的中位数最小是多少。保证nnn为奇数。
1≤n,m≤5×105,1≤ai≤m,1≤ci,K≤1091\leq n,m\leq 5\times 10^5,1\leq a_i\leq m,1\leq c_i,K\leq 10^91≤n,m≤5×105,1≤ai≤m,1≤ci,K≤109
题解
首先,我们知道⌊⌊tx⌋y⌋=⌊txy⌋\lfloor\dfrac{\lfloor\frac tx\rfloor}{y}\rfloor=\lfloor\dfrac{t}{xy}\rfloor⌊y⌊xt⌋⌋=⌊xyt⌋,那么我们可以用ci×j=min(ci×j,ci+cj)c_{i\times j}=\min(c_{i\times j},c_i+c_j)ci×j=min(ci×j,ci+cj)来更新所有ccc值,这样我们就可以得到用若干次除法将aia_iai除以一个数的最小代价。更新所有ccc值的时间复杂度为O(∑i=1mmi)=O(mlnm)O(\sum\limits_{i=1}^m\dfrac mi)=O(m\ln m)O(i=1∑mim)=O(mlnm)。
二分答案midmidmid,我们需要让⌈n2⌉\lceil\dfrac n2\rceil⌈2n⌉个aia_iai都小于等于midmidmid。对于每个aia_iai,我们可以二分求出最小的x0x_0x0使得⌊aix0⌋≤mid\lfloor\dfrac{a_i}{x_0}\rfloor\leq mid⌊x0ai⌋≤mid,那么我们取x≥x0x\geq x_0x≥x0即可使得⌊aix⌋≤mid\lfloor\dfrac{a_i}{x}\rfloor\leq mid⌊xai⌋≤mid,我们在大于等于x0x_0x0的xxx中取cxc_xcx的最小值,预处理一个后缀最小值即可。
求出将每个aia_iai降到≤mid\leq mid≤mid的最小代价wiw_iwi后,将wiw_iwi从小到大排序,然后取代价最小的⌈n2⌉\lceil\dfrac n2\rceil⌈2n⌉个wiw_iwi,,记这些wiw_iwi的和为sumsumsum。如果sum≤Ksum\leq Ksum≤K,则midmidmid合法;否则,midmidmid不合法。
这样做的话,时间复杂度为O(nlogmlogn)O(n\log m\log n)O(nlogmlogn)。我们考虑优化。
我们可以先将aia_iai从小到大排序,那么随着aia_iai的增加,x0x_0x0也是单调不降的,那么我们用一个双指针就可以O(n)O(n)O(n)求出所有aia_iai对应的x0x_0x0,并求出将这个aia_iai降到≤mid\leq mid≤mid的最小代价wiw_iwi。我们发现,此时wiw_iwi也是单调不降的,于是就不用排序了。一次判断midmidmid是否合法的时间复杂度为O(n)O(n)O(n),那么总时间复杂度为O(nlogm)O(n\log m)O(nlogm)。
因为中位数可能为000,所以需要求除数大于mmm的情况,可以将这类情况的最小代价存储在cm+1c_{m+1}cm+1中。具体操作见代码。
总时间复杂度为O(mlnm+nlogm)O(m\ln m+n\log m)O(mlnm+nlogm)。
code
#include<bits/stdc++.h>
using namespace std;
const int N=500000;
int n,m,k,a[N+5],c[N+5],s[N+5],w[N+5];
bool check(int mid){
int x=1;
for(int i=1;i<=n;i++){
if(a[i]<=mid){
w[i]=0;continue;
}
while(a[i]/x>mid) ++x;
w[i]=s[x];
}
int tmp=0;
for(int i=1;i<=n/2+1;i++){
tmp+=w[i];
if(tmp>k) return 0;
}
return 1;
}
int main()
{
// freopen("opt.in","r",stdin);
// freopen("opt.out","w",stdout);
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
for(int i=1;i<=m;i++){
scanf("%d",&c[i]);
}
sort(a+1,a+n+1);
for(int i=1;i<=m;i++){
for(int j=2;i*j<=m;j++){
c[i*j]=min(c[i*j],c[i]+c[j]);
}
}
s[m]=c[m];s[m+1]=1e9+1;
for(int i=m-1;i>=1;i--) s[i]=min(s[i+1],c[i]);
for(int i=1;i<=m;i++){
for(int j=2;;j++){
if(i*j>m){
s[m+1]=min(s[m+1],s[i]+s[j]);
break;
}
}
}
for(int i=m;i>=1;i--) s[i]=min(s[i+1],s[i]);
int l=0,r=m,mid;
while(l<=r){
mid=l+r>>1;
if(check(mid)) r=mid-1;
else l=mid+1;
}
printf("%d",r+1);
return 0;
}