题目大意:给出长度为n的0/1序列,要求支持:区间赋值(0/1),区间取反,查询区间和、区间最长连续1个数。
线段树练手题,今天闲来无事打算水一水,然后发现这个真的挺屎..
维护赋值标记、翻转标记,注意在赋值时清空翻转标记,判断好翻转与赋值的关系,细节见代码。
在赋值时记得把相反的清空
(以为指针比数组快很多于是写到一半把数组改成指针,然而Rank前100都没挤进去…可能是因为我的线段树太屎了吧0.0
#include <cstdio>
#include <algorithm>
#define N 100005
using namespace std;
struct Node {
Node* ch[2];
int l,r,sum,maxr[2],maxl[2],maxx[2],change_mark;
bool rev_mark;
Node() {}
Node(int _l,int _r):l(_l),r(_r),sum(0),change_mark(-1),rev_mark(false) {
maxl[0]=maxl[1]=maxr[0]=maxr[1]=maxx[0]=maxx[1]=0;
}
void rev() {
rev_mark=!rev_mark;
if(change_mark!=-1) {
change_mark=1-change_mark;
rev_mark=false;
}
sum=r-l+1-sum;
swap(maxl[0],maxl[1]);
swap(maxr[0],maxr[1]);
swap(maxx[0],maxx[1]);
return ;
}
void change(int x) {
sum=(r-l+1)*x;
maxr[x]=maxl[x]=maxx[x]=r-l+1;
maxr[x^1]=maxl[x^1]=maxx[x^1]=0;
change_mark=x;
rev_mark=false;
return ;
}
void pushdown() {
if(l==r) return ;
if(rev_mark) {
ch[0]->rev();
ch[1]->rev();
rev_mark=false;
}
if(change_mark!=-1) {
ch[0]->change(change_mark);
ch[1]->change(change_mark);
change_mark=-1;
}
return ;
}
void maintain() {
sum=ch[0]->sum+ch[1]->sum;
int mid=l+r>>1;
maxx[0]=max(max(ch[0]->maxx[0],ch[1]->maxx[0]),ch[0]->maxr[0]+ch[1]->maxl[0]);
maxx[1]=max(max(ch[0]->maxx[1],ch[1]->maxx[1]),ch[0]->maxr[1]+ch[1]->maxl[1]);
if(!ch[0]->sum) maxl[0]=mid-l+1+ch[1]->maxl[0];
else maxl[0]=ch[0]->maxl[0];
if(ch[0]->sum==mid-l+1) maxl[1]=mid-l+1+ch[1]->maxl[1];
else maxl[1]=ch[0]->maxl[1];
if(!ch[1]->sum) maxr[0]=r-mid+ch[0]->maxr[0];
else maxr[0]=ch[1]->maxr[0];
if(ch[1]->sum==r-mid) maxr[1]=r-mid+ch[0]->maxr[1];
else maxr[1]=ch[1]->maxr[1];
return ;
}
void* operator new(size_t) {
static Node *C,*mempool;
if(C==mempool) mempool=(C=new Node[1<<19])+(1<<19);
return C++;
}
}*root;
int a[N];
void init(Node*&,int,int);
int query_sum(Node*,int,int);
int query_max(Node*,int,int);
int query_maxl(Node*,int,int);
int query_maxr(Node*,int,int);
void update(Node*,int,int,int);
void Reverse(Node*,int,int);
int main() {
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",a+i);
init(root,1,n);
while(m--) {
int mode,x,y;
scanf("%d%d%d",&mode,&x,&y);
x++, y++;
if(mode==0) update(root,x,y,0);
else if(mode==1) update(root,x,y,1);
else if(mode==2) Reverse(root,x,y);
else if(mode==3) printf("%d\n",query_sum(root,x,y));
else printf("%d\n",query_max(root,x,y));
}
return 0;
}
void init(Node*& o,int l,int r) {
o=new Node(l,r);
if(l==r) {
o->sum=a[l];
o->maxl[a[l]]=o->maxr[a[l]]=o->maxx[a[l]]=1;
return ;
}
int mid=l+r>>1;
init(o->ch[0],l,mid), init(o->ch[1],mid+1,r);
o->maintain();
return ;
}
int query_sum(Node* o,int l,int r) {
if(o->l==l && o->r==r) return o->sum;
o->pushdown();
int mid=o->l+o->r>>1;
if(r<=mid) return query_sum(o->ch[0],l,r);
if(l>mid) return query_sum(o->ch[1],l,r);
return query_sum(o->ch[0],l,mid)+query_sum(o->ch[1],mid+1,r);
}
int query_maxl(Node* o,int l,int r) {
if(o->l==l && o->r==r) return o->maxl[1];
o->pushdown();
int mid=o->l+o->r>>1;
if(r<=mid) return query_maxl(o->ch[0],l,r);
if(l>mid) return query_maxl(o->ch[1],l,r);
int tmp=query_maxl(o->ch[0],l,mid);
if(tmp==mid-l+1) tmp+=query_maxl(o->ch[1],mid+1,r);
return tmp;
}
int query_maxr(Node* o,int l,int r) {
if(o->l==l && o->r==r) return o->maxr[1];
o->pushdown();
int mid=o->l+o->r>>1;
if(r<=mid) return query_maxr(o->ch[0],l,r);
if(l>mid) return query_maxr(o->ch[1],l,r);
int tmp=query_maxr(o->ch[1],mid+1,r);
if(tmp==r-mid) tmp+=query_maxr(o->ch[0],l,mid);
return tmp;
}
int query_max(Node* o,int l,int r) {
if(o->l==l && o->r==r) return o->maxx[1];
o->pushdown();
int mid=o->l+o->r>>1;
if(r<=mid) return query_max(o->ch[0],l,r);
if(l>mid) return query_max(o->ch[1],l,r);
return max(max(query_max(o->ch[0],l,mid),query_max(o->ch[1],mid+1,r)),query_maxr(o->ch[0],l,mid)+query_maxl(o->ch[1],mid+1,r));
}
void update(Node* o,int l,int r,int v) {
if(o->l==l && o->r==r) {
o->change(v);
return ;
}
o->pushdown();
int mid=o->l+o->r>>1;
if(r<=mid) update(o->ch[0],l,r,v);
else if(l>mid) update(o->ch[1],l,r,v);
else update(o->ch[0],l,mid,v), update(o->ch[1],mid+1,r,v);
o->maintain();
return ;
}
void Reverse(Node* o,int l,int r) {
if(o->l==l && o->r==r) {
o->rev();
return ;
}
o->pushdown();
int mid=o->l+o->r>>1;
if(r<=mid) Reverse(o->ch[0],l,r);
else if(l>mid) Reverse(o->ch[1],l,r);
else Reverse(o->ch[0],l,mid), Reverse(o->ch[1],mid+1,r);
o->maintain();
return ;
}