题目大意:
给你一个长度为n(1<=n<=105)n(1<=n<=105)的序列,支持两种操作:
1.修改某个位置的值。
2.在区间[l,r][l,r]里选择不超过k(1<=k<=20)k(1<=k<=20)个不相交的子区间和的最大值。
操作数m<=105m<=105。
分析:
因为kk比较小,可以考虑使用费用流的模型。源点向每个点ii连一条流量为,费用为00的边;向i+1i+1连一条费用为aiai,流量为11的边;向汇点TT连流量为,费用为00的边。那么,一滴从的流,代表选取区间[i,j−1][i,j−1]。可以发现的是,每多流一滴流量,就多一个区间。问题就变成了使用不多于kk滴流量,能得到的最大费用。因为每次增广只会多一点流量,而每次要查询的最长路就是每次增广的贡献,这个贡献是由一段区间构成。
所以,每次增广相当于找到询问区间的最大子区间和,然后再把该区间取反(反向弧的费用是正向弧的相反数)。举个例子,如果第一次选了区间,然后取反后取了区间[3,6][3,6],此时相当于取了区间[1,2][1,2]和[5,6][5,6]。增广完kk次,或者增广贡献小于时,就得到了答案。
因为不可能同时选共起点或共终点的区间,假如我们第一次取了[1,4][1,4],那么一定不可能取反第二次后会取[1,5][1,5],这样得到的区间是[4,5][4,5],因为我们保证每次增广的贡献大于等于00,这个区间是比大的,与一开始就取最大子区间矛盾。所以每次增广必然会增加一个区间。
现在就可以用线段树维护了,相当于维护一个最大子区间值,当然,同时肯定要维护最大左区间和与最大右区间和区间和,单点修改就不用lazylazy了,不然真的太毒瘤了。当然,由于有反转操作,还需要维护最小的子区间和,反转时swapswap一下就好,打上反转标记。
定义一个区间类型及使用重载运算符来处理区间加可以减少代码量。
代码:
#include <iostream>
#include <cstdio>
#include <cmath>
#include <queue>
const int maxn=1e5+7;
using namespace std;
int n,m,x,op,y,k;
struct rec{
int l,r,s;
};
rec operator + (rec x,rec y)
{
return (rec){x.l,y.r,x.s+y.s};
}
bool operator < (rec x,rec y)
{
return x.s<y.s;
}
struct node{
rec smax,smin,lmax,rmax,lmin,rmin,sum;
int rev;
}t[maxn*4];
queue <node> q;
void neww(int p,int l,int k)
{
t[p].smax=(rec){l,l,k};
t[p].lmax=(rec){l,l,k};
t[p].rmax=(rec){l,l,k};
t[p].smin=(rec){l,l,k};
t[p].lmin=(rec){l,l,k};
t[p].rmin=(rec){l,l,k};
t[p].sum=(rec){l,l,k};
}
node merge(node x,node y)
{
node z;
z.smax=max(x.smax,y.smax);
z.smax=max(z.smax,x.rmax+y.lmax);
z.smin=min(x.smin,y.smin);
z.smin=min(z.smin,x.rmin+y.lmin);
z.lmax=max(x.lmax,x.sum+y.lmax);
z.rmax=max(y.rmax,x.rmax+y.sum);
z.lmin=min(x.lmin,x.sum+y.lmin);
z.rmin=min(y.rmin,x.rmin+y.sum);
z.sum=x.sum+y.sum;
z.rev=0;
return z;
}
void clean(int p)
{
swap(t[p].smax,t[p].smin);
swap(t[p].lmax,t[p].lmin);
swap(t[p].rmax,t[p].rmin);
t[p].smax.s*=-1;
t[p].smin.s*=-1;
t[p].lmax.s*=-1;
t[p].lmin.s*=-1;
t[p].rmax.s*=-1;
t[p].rmin.s*=-1;
t[p].sum.s*=-1;
t[p].rev^=1;
}
void change(int p,int l,int r,int x,int k)
{
if (l==r)
{
neww(p,l,k);
return;
}
int mid=(l+r)/2;
if (t[p].rev)
{
clean(p*2);
clean(p*2+1);
t[p].rev^=1;
}
if (x<=mid) change(p*2,l,mid,x,k);
else change(p*2+1,mid+1,r,x,k);
t[p]=merge(t[p*2],t[p*2+1]);
}
void rev(int p,int l,int r,int x,int y)
{
if ((l==x) && (r==y))
{
clean(p);
return;
}
int mid=(l+r)/2;
if (t[p].rev)
{
clean(p*2);
clean(p*2+1);
t[p].rev^=1;
}
if (y<=mid) rev(p*2,l,mid,x,y);
else if (x>mid) rev(p*2+1,mid+1,r,x,y);
else
{
rev(p*2,l,mid,x,mid);
rev(p*2+1,mid+1,r,mid+1,y);
}
t[p]=merge(t[p*2],t[p*2+1]);
}
node query(int p,int l,int r,int x,int y)
{
if ((l==x) && (r==y)) return t[p];
int mid=(l+r)/2;
if (t[p].rev)
{
clean(p*2);
clean(p*2+1);
t[p].rev^=1;
}
if (y<=mid) return query(p*2,l,mid,x,y);
else if (x>mid) return query(p*2+1,mid+1,r,x,y);
else
{
return merge(query(p*2,l,mid,x,mid),query(p*2+1,mid+1,r,mid+1,y));
}
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;i++)
{
scanf("%d",&x);
change(1,1,n,i,x);
}
scanf("%d",&m);
for (int i=1;i<=m;i++)
{
scanf("%d",&op);
if (op==0)
{
scanf("%d%d",&x,&k);
change(1,1,n,x,k);
}
else
{
scanf("%d%d%d",&x,&y,&k);
int ans=0;
while (k--)
{
node d=query(1,1,n,x,y);
if (d.smax.s<0) break;
ans+=d.smax.s;
rev(1,1,n,d.smax.l,d.smax.r);
q.push(d);
}
printf("%d\n",ans);
while (!q.empty())
{
node d=q.front();
q.pop();
rev(1,1,n,d.smax.l,d.smax.r);
}
}
}
}