题意:给你一个长度为n的数组a和长度为m的数组b 将数组a和数组b保留相对位置的基础上拼成数组c 求
的最小值
题解:首先有一个贪心的小思路,优先选择大的,但是因为选择的时候有一定的限制性,a数组前面的元素一定比后面的元素先选,那么你若按照这个思路直接去选,肯定是不对的(随便找都是反例),其实我们可以去选择后面的大元素与前面的元素合并,然后优先选择平均数最大的元素块,(初始认为所有元素都是大小为1的元素块)若后面的元素块的平均数大于前面的元素块,那么便把两个元素块合并,合并完成后,a和b数组都会变成单调递减的,然后直接贪心即可
#include<bits/stdc++.h>
using namespace std;
#define Sheryang main
const int maxn=2e5+7;
typedef long long ll;
const int mod=1e9+7;
///#define getchar()(p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
///char buf[(1 << 21) + 1], *p1 = buf, *p2 = buf;
#define IO cin.tie(0),ios::sync_with_stdio(false);
#define pi acos(-1)
#define PII pair<ll,ll>
ll read(){ll c = getchar(),Nig = 1,x = 0;while(!isdigit(c) && c!='-')c = getchar();if(c == '-')Nig = -1,c = getchar();while(isdigit(c))x = ((x<<1) + (x<<3)) + (c^'0'),c = getchar();return Nig*x;}
#define read read()
/** keep hungry and keep calm! **/
struct node{
double sum;
int cnt;
friend bool operator <(node a,node b){
return a.sum*b.cnt < a.cnt*b.sum;
}
friend node operator +(node a,node b){
return {a.sum+b.sum,a.cnt+b.cnt};
}
};
ll a[maxn],b[maxn];
node S[maxn],T[maxn];
vector<int> tmp;
void add(int op,int l,int r){
for(int i=l;i<=r;i++){
if(op == 1){
tmp.push_back(a[i]);
}else{
tmp.push_back(b[i]);
}
}
}
int Sheryang(){
int TT=read;
for(int cas=1;cas<=TT;cas++){
int n=read,m=read;
for(int i=1;i<=n;i++){
a[i] = read;
}
for(int i=1;i<=m;i++){
b[i] = read;
}
printf("Case %d: ",cas);
int cnt1 = 0;
for(int i=1;i<=n;i++){
S[++cnt1] = {a[i]*1.0,1};
while(cnt1 > 1 && S[cnt1-1] < S[cnt1]){
S[cnt1-1] = S[cnt1-1] + S[cnt1];
cnt1 --;
}
}
S[++cnt1] = {-1000,1};
int cnt2 = 0;
for(int i=1;i<=m;i++){
T[++cnt2] = {b[i]*1.0,1};
while(cnt2 > 1 && T[cnt2-1] < T[cnt2]){
T[cnt2-1] = T[cnt2-1] + T[cnt2];
cnt2 --;
}
}
T[++cnt2] = {-1000,1};
tmp.clear();
int pos1 = 1 ,pos2 = 1 , fa = 1 , fb = 1;
while(pos1 < cnt1 || pos2 < cnt2){
if(S[pos1]<T[pos2]){
add(2,fb,fb + T[pos2].cnt - 1);
fb = fb + T[pos2].cnt ;
pos2 ++;
}else{
add(1,fa,fa + S[pos1].cnt - 1);
fa = fa + S[pos1].cnt ;
pos1 ++;
}
}
ll ans = 0;
for(int i=0;i<tmp.size();i++){
ans += 1LL*(i+1)*tmp[i];
}
printf("%lld\n",ans);
}
return 0;
}