题目大意
给定数列
{an}
{
a
n
}
,要求维护以下操作和询问:
- 将 ai a i 赋值为 val v a l
- 在区间 [l,r] [ l , r ] 中选出最多 k k 个互不相交的子段列,最大化这些选中的数的和,输出这个最大值
操作和询问共个
分析:
首先看一下暴力怎么解决这个问题:
把一个数拆成两个点,作为
X
X
部和部
X
X
部的点向部的对应点连边,容量为1,费用为
ai
a
i
原点向
X
X
部连边,容量为1,费用为0
部向汇点连边,容量为1,费用为0
相邻点从
Y
Y
部连向部,容量为1,费用为0
跑
k
k
次流量为1的最大费用流即可
上述方法边数有,显然会TLE
那么我们就需要看出算法的实质:
每次增广的过程,实质上就是取一段和最大的子序列,并将其反转
很显然对于这类序列上的操作,可以用线段树去实现
正解
费用流的构图,线段树手动模拟增广过程
线段树维护方法:
维护一段区间的最大子序列
每次我们提取出一个最大子序列时,我们要把这个子序列取反(*-1,防止重复选择),所以还需要维护最小子序列
每进行一次取反,当前最大和子序列一定变成最小和子序列,最小和子序列一定变成最大,那么直接swap一下就可以了
鉴于一次询问需要增广K次,每一次都要要取反,所以需要开一个栈记录一下当前询问所反转的所有区间,在结束时还原
总的时间复杂度是 O(knlogn) O ( k n l o g n )
看一下维护:
struct node{
int lx,rx,mx,sum;
int lp,rp,p1,p2;
void init(int l,int val) {
lp=rp=p1=p2=l;
lx=rx=mx=val;
sum=val;
}
};
struct Tree{
int l,r,a,b;
bool flag;
node mn,mx;
void init(int val) {
mn.init(l,-val);
mx.init(l,val);
}
};
Tree t[N<<2];
lx
l
x
:从左端点开始的最大子序列
rx
r
x
:从右端点开始的最大子序列
mx
m
x
:整个区间的最大子序列
sum
s
u
m
:区间和
lp
l
p
:
lx
l
x
的右端点
rp
r
p
:
rx
r
x
的左端点
p1
p
1
:
mx
m
x
的左端点
p2
p
2
:
mx
m
x
的右端点
flag
f
l
a
g
:区间翻转标记
mn
m
n
:记录区间的最大子序列
mx
m
x
:记录区间的最小子序列
init i n i t 函数:插入一个值(mx正值,mn负值)
void push(int bh) {
if (t[bh].l==t[bh].r) return;
if (t[bh].flag) {
swap(t[bh<<1].mx,t[bh<<1].mn);
swap(t[bh<<1|1].mx,t[bh<<1|1].mn);
t[bh<<1].flag^=1; t[bh<<1|1].flag^=1;
t[bh].flag^=1;
}
}
处理区间翻转: mn<=>mx m n <=> m x
node merge(node a,node b) {
node t;
t.sum=a.sum+b.sum;
t.lx=a.lx; t.lp=a.lp;
if (a.sum+b.lx>t.lx) t.lx=a.sum+b.lx,t.lp=b.lp;
t.rx=b.rx; t.rp=b.rp;
if (b.sum+a.rx>t.rx) t.rx=b.sum+a.rx,t.rp=a.rp;
t.mx=a.rx+b.lx;
t.p1=a.rp; t.p2=b.lp;
if (t.mx<a.mx) t.mx=a.mx,t.p1=a.p1,t.p2=a.p2;
if (t.mx<b.mx) t.mx=b.mx,t.p1=b.p1,t.p2=b.p2;
return t;
}
重要的合并函数
按照各变量的意义转移即可
void solve(int l,int r,int k) {
int ans=0;
top=0;
while (k--) {
node t=ask(1,l,r);
if (t.mx>0) ans+=t.mx;
else break;
rever(1,t.p1,t.p2);
++top;
q[top].x=t.p1; q[top].y=t.p2;
}
for (int i=top;i>0;i--)
rever(1,q[i].x,q[i].y); //消除影响
printf("%d\n",ans);
}
每次我们提取出一个最大子序列
t
t
,加入答案(如果小于0就停止操作)
q是记录翻转区间的栈,处理完之后消除翻转影响
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#define ll long long
using namespace std;
const int N=100005;
struct node{
int lx,rx,mx,sum;
int lp,rp,p1,p2;
void init(int l,int val) {
lp=rp=p1=p2=l;
lx=rx=mx=val;
sum=val;
}
};
struct Tree{
int l,r,a,b;
bool flag;
node mn,mx;
void init(int val) {
mn.init(l,-val);
mx.init(l,val);
}
};
Tree t[N<<2];
struct point{
int x,y;
};
point q[N];
int n,m,a[N],top=0;
void push(int bh) {
if (t[bh].l==t[bh].r) return;
if (t[bh].flag) {
swap(t[bh<<1].mx,t[bh<<1].mn);
swap(t[bh<<1|1].mx,t[bh<<1|1].mn);
t[bh<<1].flag^=1; t[bh<<1|1].flag^=1;
t[bh].flag^=1;
}
}
node merge(node a,node b) {
node t;
t.sum=a.sum+b.sum;
t.lx=a.lx; t.lp=a.lp;
if (a.sum+b.lx>t.lx) t.lx=a.sum+b.lx,t.lp=b.lp;
t.rx=b.rx; t.rp=b.rp;
if (b.sum+a.rx>t.rx) t.rx=b.sum+a.rx,t.rp=a.rp;
t.mx=a.rx+b.lx;
t.p1=a.rp; t.p2=b.lp;
if (t.mx<a.mx) t.mx=a.mx,t.p1=a.p1,t.p2=a.p2;
if (t.mx<b.mx) t.mx=b.mx,t.p1=b.p1,t.p2=b.p2;
return t;
}
void update(int bh) {
t[bh].mn=merge(t[bh<<1].mn,t[bh<<1|1].mn);
t[bh].mx=merge(t[bh<<1].mx,t[bh<<1|1].mx);
}
void build(int bh,int l,int r) {
t[bh].l=l; t[bh].r=r;
if (l==r) {
t[bh].init(a[l]);
return;
}
int mid=(l+r)>>1;
build(bh<<1,l,mid);
build(bh<<1|1,mid+1,r);
update(bh);
}
void rever(int bh,int L,int R) {
push(bh);
int l=t[bh].l,r=t[bh].r,mid=(l+r)>>1;
if (l>=L&&r<=R) {
swap(t[bh].mn,t[bh].mx);
t[bh].flag^=1;
return;
}
if (L<=mid) rever(bh<<1,L,R);
if (R>mid) rever(bh<<1|1,L,R);
update(bh);
}
node ask(int bh,int L,int R) {
push(bh);
int l=t[bh].l,r=t[bh].r,mid=(l+r)>>1;
if (l==L&&r==R) return t[bh].mx;
if (R<=mid) return ask(bh<<1,L,R);
else if (L>mid) return ask(bh<<1|1,L,R);
else return merge(ask(bh<<1,L,mid),ask(bh<<1|1,mid+1,R));
}
void solve(int l,int r,int k) {
int ans=0;
top=0;
while (k--) {
node t=ask(1,l,r);
if (t.mx>0) ans+=t.mx;
else break;
rever(1,t.p1,t.p2);
++top;
q[top].x=t.p1; q[top].y=t.p2;
}
for (int i=top;i>0;i--)
rever(1,q[i].x,q[i].y); //消除影响
printf("%d\n",ans);
}
void change(int bh,int pos,int val) {
push(bh);
int l=t[bh].l,r=t[bh].r,mid=(l+r)>>1;
if (l==r) {
t[bh].init(val); return;
}
if (pos<=mid) change(bh<<1,pos,val);
else change(bh<<1|1,pos,val);
update(bh);
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
build(1,1,n);
scanf("%d",&m);
int opt,l,r;
while (m--) {
scanf("%d%d%d",&opt,&l,&r);
if (opt==1) {
int x; scanf("%d",&x);
solve(l,r,x);
}
else change(1,l,r);
}
return 0;
}