题意:
三个有n个正整数的数组满足a[i]<b[i]<c[i],固定一个r,根据不同的x,y生成不同的数列,每对合法的x,y(x<y)生成两个区间[x,x+r-1],[y,y+r-1],合法指生成的区间不会越界,如果一个点i被两个区间覆盖,他的值就是c[i],如果被一个覆盖就是b[i],否则a[i],一个数列的价值等于所有点的值之和,问在所有x,y生成的序列中价值第k大的是多少
分析:
首先把b[i],c[i]减去a[i],最后再加回来方便计算
这种第k大的经典套路就是二分答案check有多少个值比他小
在这个题的check里,我们分两部分讨论:x和y的区间有重叠的和没有重叠的,对于没有重叠的十分容易,枚举每个点作为右区间的起始端点再离散化区间b[i]的和用树状数组统计与多少个前面的区间值相加<mid即可
对于重叠的,表面上看重叠区间之间会互相影响,无法像不重叠那样直接将区间b[i]和插入树状数组
这就需要将加入树状数组的值凑成一个只与当前点相关的值,具体凑法如下:
#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
#include<cstring>
#define sc(x) scanf("%d", &x)
#define pb push_back
#define ALL(x) x.begin(),x.end()
using namespace std;
typedef long long LL;
const int maxn = 3e4+10;
int n, r, a[maxn], b[maxn], c[maxn];
LL bit[maxn*3], k, preb[maxn], prec[maxn];
vector<LL> val;
inline int lowbit(int x){return x&-x;}
void add(int x, int v){
for(int i = x; i < maxn*3; i += lowbit(i))
bit[i] += v;
}
LL sum(int x){
LL res = 0;
for(int i = x; i > 0; i -= lowbit(i))
res += bit[i];
return res;
}
LL check(LL mid){
//printf("check(%lld):",mid);
memset(bit, 0, sizeof(bit));
LL res = 0;
for(int i = r+1, j = 1; i+r-1 <= n; i++, j++){
int id = lower_bound(ALL(val),preb[j+r-1]-preb[j-1])-val.begin()+1;
add(id,1);
LL tmp = preb[i+r-1]-preb[i-1];
tmp = mid-tmp;
id = lower_bound(ALL(val),tmp)-val.begin();
res += sum(id);
}
//printf("%lld,", res);
memset(bit, 0, sizeof(bit));
for(int i = 2; i+r-1 <= n; i++){
LL tmp = prec[i+r-2] - preb[i+r-2] - preb[i-2];
int id = lower_bound(ALL(val),tmp)-val.begin()+1;
add(id,1);
if(i-r >= 1){
tmp = prec[i-r+r-1] - preb[i-r+r-1] - preb[i-r-1];
id = lower_bound(ALL(val),tmp)-val.begin()+1;
add(id, -1);
}
tmp = preb[i+r-1] + preb[i-1] - prec[i-1];
tmp = mid-tmp;
id = lower_bound(ALL(val),tmp)-val.begin();
res += sum(id);
}
//printf("%lld\n", res);
return res;
}
int main(){
sc(n); sc(r); cin >> k; k--;
LL sum = 0;
for(int i = 1; i <= n; i++) sc(a[i]), sum += a[i];
for(int i = 1; i <= n; i++) sc(b[i]), b[i] -= a[i];
for(int i = 1; i <= n; i++) sc(c[i]), c[i] -= a[i];
for(int i = 1; i <= n; i++){
preb[i] = preb[i-1] + b[i];
prec[i] = prec[i-1] + c[i];
if(i >= r){
val.pb(prec[i]-preb[i-r]-preb[i]);
val.pb(preb[i]-preb[i-r]);
}
}
sort(ALL(val));
val.erase(unique(ALL(val)), val.end());
LL L = 0, R = 3e10+10;
while(L+1 < R){
LL mid = (L+R)/2;
//printf("check(%lld):%lld, k:%lld\n",mid,check(mid),k);
if(check(mid) > k) R = mid;
else L = mid;
}
cout << L+sum << endl;
return 0;
}