题目大意:给你一个序列,1~n,m个操作,每个操作 a、b,就是把从第a个到第b个数字拿出来,翻转一下,放到后面,问你最终序列式多少?
思路:把a~b拿出来,再放回去,可以用伸展树的分裂,合并来做,翻转的话,就在每个节点上放一个延时标记 flip ,就像线段树一样,访问的时候push_down() 就行。
这里要注意,每次push_down() 要放在比较前面,因为比较是跟左儿子有关的,翻转会交换左右儿子。
这道题还有一个非常重要的地方要注意,就是要增设一个虚拟节点,因为 split 要保证 left != null。
还有用 null 来代替 NULL 会省很多的麻烦,而且不容易出错。。
第一次写的伸展树,代码很多是参考书上和别人来的,唯一印象最深刻的是我调试了一个下午加晚上。。 = =
代码如下:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
struct Node
{
Node* ch[2];
int v;//节点权值
int size;//子树大小,即节点总数
int flip;//翻转延时标记
Node(){}
Node(int v,Node* null) : v(v) {ch[0] = ch[1] = null;size = v + 1;flip = 0;}
int cmp(int x)
{
int ss = ch[0]->size;
if(x == ss+ 1) return -1;
else return x<ss + 1 ? 0:1;
}
void maintain()
{
size = 1;
size += ch[0]->size + ch[1]->size;
}
};
struct Splay
{
Node* root;
Node* null;
void init(int n)
{
null = new Node;//空指针
null->ch[0] = null->ch[1] = null;
null->v = null->size = null->flip = 0;
root = null;
build(root,n);
}
void build(Node* &o,int n)
{
if(n>=0)//增设虚拟节点0,split 限制 left != null !
{
o = new Node(n,null);
build(o->ch[0],n - 1);
}
}
void rotate(Node* &o,int d)
{
Node* k = o->ch[d^1];o->ch[d^1] = k->ch[d];k->ch[d] = o;
o->maintain();k->maintain();o = k;
}
void debug(Node* o)
{
if(o == null) return ;
printf("%d(",o->v);
if(o->ch[0]!=null) printf("%d,",o->ch[0]->v);
else printf("null,");
if(o->ch[1]!=null) printf("%d",o->ch[1]->v);
else printf("null");
puts(")");
if(o->ch[0]!=null) debug(o->ch[0]);
if(o->ch[1]!=null) debug(o->ch[1]);
}
void push_down(Node* &o)
{
if(o->flip)
{
o->flip = 0;
swap(o->ch[0],o->ch[1]);
o->ch[0]->flip ^= 1;
o->ch[1]->flip ^= 1;
}
}
void splay(Node* &o,int k)
{
push_down(o);//先 push_down() 再比较
int d = o->cmp(k);
if(d == 1) k -= o->ch[0]->size + 1;
if(d!=-1)
{
Node* p = o->ch[d];
push_down(p);//同理,要先 push_down()
int d2 = p->cmp(k);
if(d2 != -1)
{
int k2 = d2 == 0? k: k - p->ch[0]->size - 1;
splay(p->ch[d2],k2);
if(d == d2) rotate(o,d^1);
else rotate(o->ch[d],d);
}
rotate(o,d^1);
}
}
Node* merge(Node* left,Node* right)
{
splay(left,left->size);
left->ch[1] = right;
left->maintain();
return left;
}
void split(Node* o,int k,Node* &left,Node* &right)
{
splay(o,k);
left = o;
right = o->ch[1];
o->ch[1] = null;
left->maintain();
}
void solve(int a,int b)
{
Node *left,*right,*mid,*o;
split(root,a,left,o);
split(o,b - a + 1,mid,right);
mid->flip ^= 1;
root = merge(merge(left,right),mid);
}
void print(Node* &root)
{
if(root == null) return ;
push_down(root);
print(root->ch[0]);
if(root->v > 0) printf("%d\n",root->v);
print(root->ch[1]);
}
} sp;
int main()
{
int n,m;
while(~scanf("%d%d",&n,&m))
{
sp.init(n);
int a,b;
for(int i = 1;i<=m;i++)
{
scanf("%d%d",&a,&b);
sp.solve(a,b);
}
sp.print(sp.root);
}
return 0;
}