思路:
**p[i] a[i]~a[n]最大的结果
**cnt[i] a[1]~a[i]最大的结果
**mx[i] a[1]~a[i]最大值
预处理出p[i],cnt[i],mx[i]。
对于每次查询,在修改a[x]为y,如果1~x的最大值大于等于y,s+=cnt[x-1],否则s+=cnt[x-1]+1;
然后对于后半区间,如果y>=mx[x-1],s+=(x+1~n)中第一个大于y处的cnt[i],否则加上(x+1~n)中第一个大于mx[x-1]处的cnt[i]
对于如何找到i~n区间内第一个大于等于x的数的位置,我们可以先将1~i中的每个数tag标记为1,然后pushup的时候处理一下,使1~i中的最大值为0,这样就可以很方便地求出。然后消掉a[1]~a[i-1]的标记即可。具体请看代码
AC代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
#define ls rt<<1
#define rs rt<<1|1
#define ll t[rt].l
#define rr t[rt].r
struct node
{
int l,r,mx,tag;
}t[maxn*4];
int n,m;
int a[maxn],p[maxn],cnt[maxn],mx[maxn];
void pushup(int rt)
{
if(t[ls].tag==1&&t[rs].tag==0) t[rt].mx=t[rs].mx;
else if(t[ls].tag==0&&t[rs].tag==1) t[rt].mx=t[ls].mx;
else if(t[ls].tag==1&&t[rs].tag==1) t[rt].mx=0;
else t[rt].mx=max(t[ls].mx,t[rs].mx);
}
void build(int rt,int l,int r)
{
t[rt].l=l;
t[rt].r=r;
t[rt].tag=0;
if(l==r){
t[rt].mx=a[l];
return ;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int v)
{
if(l<=ll&&rr<=r){
t[rt].tag=v;
return ;
}
int mid=(ll+rr)>>1;
if(l<=mid) update(ls,l,r,v);
if(r>mid) update(rs,l,r,v);
pushup(rt);
}
int query(int rt,int l,int r,int val)
{
if(r<l) return 0;
if(t[rt].mx<val) return 0;
if(ll==rr) return ll;
int mid=(ll+rr)>>1;
if(t[ls].mx>=val&&l<=mid) return query(ls,l,r,val);
if(t[rs].mx>=val&&r>mid) return query(rs,l,r,val);
return 0;
}
stack<int> sta;
int main()
{
int T;
scanf("%d",&T);
while(T--){
while(sta.size())sta.pop();
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
int sz=1;
int now=a[1];
cnt[1]=1;
mx[1]=now;
for(int i=2;i<=n;i++){
if(a[i]>now){
now=a[i];
sz++;
cnt[i]=cnt[i-1]+1;
}
else cnt[i]=cnt[i-1];
mx[i]=now;
}
p[n]=1;
sta.push(a[n]);
for(int i=n-1;i>=1;i--){
while(sta.size()&&sta.top()<=a[i]) sta.pop();
sta.push(a[i]);
p[i]=sta.size();
}
int x,y;
build(1,1,n);
while(m--){
scanf("%d%d",&x,&y);
int s=0;
if(mx[x-1]>=y) s=cnt[x-1];
else s=cnt[x-1]+1;
int pos;
update(1,1,x,1);
if(y>=mx[x-1]) pos=query(1,x+1,n,y+1);
else pos=query(1,x+1,n,mx[x-1]+1);
update(1,1,x,0);
if(pos) s+=p[pos];
printf("%d\n",s);
}
}
return 0;
}