SplayTree,即伸展树,是竞赛中很常用的一种平衡树,它可以实现序列的分裂合并,并且保证均摊复杂度是O(nlogn)。
与Treep类似的,SplayTree也是基于旋转的,Splay,即伸展操作的实质就是在一棵子树中找到一个节点并把它旋转到根,复杂度靠谱的Splay有三种旋转(注意不要擅自修改旋转的方式,因为SplayTree复杂度的证明正是基于它的旋转方式的,否则可能会变成复杂度不靠谱的SpalyTree等):
1.目标节点是根节点的儿子节点:
…………..root………..
…………./………………
…………k……….(图略劣质别介意)
这时只需要进行一次旋转操作
2.目标节点是根节点下两层节点并且与根节点成直线(即是根节点的左儿子的左儿子,或是根节点的右儿子的右儿子)
………..root……
………/……………
……..s…………….
……./……………..
……k………………
这时只要连续两次旋转根节点就可以
3.目标节点是根节点下两层节点并且与根节点成折线(即是根节点的左儿子的右儿子,或是根节点的右儿子的左儿子)
………..root………..
………../………………
……….s………………
………..|………………
………..k…………….(由于\显示不出来,所以用|代替)
这时只要先将k旋到s的位置,再对root进行一次旋转就可以
看起来好像情况很复杂,但是,看了某汝佳的书后才发现代码可以这么缩:
void splay(pnode &p,int k){
p->pushdown();
int d1=p->cmp(k);
if(d1!=-1){
p->ch[d1]->pushdown();int d2=p->ch[d1]->cmp(k);
if(d2!=-1){
splay(p->ch[d1]->ch[d2],k);
if(d1==d2)rot(p,d1^1);
else rot(p->ch[d1],d2^1);
}
rot(p,d1^1);
}
}
另外,为了便于写代码,对cmp函数做一个小改动,即在比较的同时对k进行一个改动,若搜索的是右子树,就将k减去(左子树的个数+1),这样方便下一步的splay,代码如下:
int cmp(int &k){
if(k<ch[0]->num+1)return 0;
if(k==ch[0]->num+1)return -1;
k-=ch[0]->num+1;return 1;
}
这样,splay操作就完成了,再次提醒:
千万不要自作聪明的修改某些操作,建议将每种旋转后的树形都画一下,然后你会发现,和自己想出来的旋转方式最终的结果是不一样的。
SplayTree最常见的应用是维护一个可删除插入的序列,有了splay这个操作,我们就可以方便的实现各种操作了。
删除:
删除一段序列,只要把要删除的序列的第一个元素前一个元素旋到树根,再把要删除的序列的最后一个元素的后一个元素旋到树根的右儿子,这时,树根的右儿子的左儿子就是需要删除的序列。不要忘记,删除完之后标记要更新。
插入:
为了插入一段序列,首先要把插入的序列变成一个SplayTree(如果一个个插入的话,遇到操作次数多的题目就会TLE),怎么快速的将一段序列变成一个SplayTree?可以每次将序列从中间分成左右两段,左边的放左子树,右边的放右子树,代码如下:
pnode build(int l,int r,int*arr){
if(l>r)return null;
int mid=(l+r)/2;
pnode p=newnode(*(arr+mid));
p->ch[0]=build(l,mid-1,arr);p->ch[1]=build(mid+1,r,arr);
p->maintain();
return p;
}
这样,将插入位置的前一个元素旋到树根,再讲插入位置的后一个元素旋到根的右儿子(这时,根的右儿子的左儿子一定是空),再把造好的新SplayTree作为根的右儿子的左子树。插入玩以后依然要更新标记。
其他的一些操作如求和,区间修改等就不多说了,反正就是利用伸展操作把需要处理的区间拎到一棵单独的子树上,然后各种标记,另外,懒标记的思想在SplayTree也是适用的,这使得SplayTree的功能要强大的多。
另外,在学习SplayTree的一定不能错过BZOJ1500,附上此题代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define INF -1e9
#define maxn 500006
using namespace std;
struct node{
node* ch[2];
int key,num,sum,maxs,maxl,maxr,tag,flip;
void add_tag(int k){
key=tag=k;sum=k*num;
if(k>0)maxs=maxl=maxr=k*num;
else maxs=maxl=maxr=k;
}
void add_flip(){
swap(maxl,maxr);swap(ch[0],ch[1]);flip^=1;
}
void pushdown(){
if(tag!=INF)ch[0]->add_tag(tag),ch[1]->add_tag(tag),tag=INF;
if(flip)ch[0]->add_flip(),ch[1]->add_flip(),flip=0;
}
int maintain(){
sum=ch[0]->sum+key+ch[1]->sum;num=ch[0]->num+1+ch[1]->num;
maxl=max(max(ch[0]->maxl,ch[0]->sum+key),ch[0]->sum+key+ch[1]->maxl);
maxr=max(max(ch[1]->maxr,ch[1]->sum+key),ch[1]->sum+key+ch[0]->maxr);
maxs=max(key,max(ch[0]->maxr+key+ch[1]->maxl,max(max(ch[0]->maxs,ch[1]->maxs),max(ch[0]->maxr+key,key+ch[1]->maxl))));
}
int cmp(int &k){
if(k<ch[0]->num+1)return 0;
if(k==ch[0]->num+1)return -1;
k-=ch[0]->num+1;return 1;
}
}nul;
typedef node* pnode;
pnode rt,null=&nul;
pnode newnode(int key){
pnode p=new node;p->key=p->sum=p->maxs=p->maxl=p->maxr=key;p->ch[0]=p->ch[1]=null;
p->tag=INF;p->flip=0;p->num=1;
return p;
}
void rot(pnode &p,int d){
pnode k=p->ch[d^1];p->ch[d^1]=k->ch[d];k->ch[d]=p;
p->maintain();k->maintain();p=k;
}
void splay(pnode &p,int k){
p->pushdown();
int d1=p->cmp(k);
if(d1!=-1){
p->ch[d1]->pushdown();int d2=p->ch[d1]->cmp(k);
if(d2!=-1){
splay(p->ch[d1]->ch[d2],k);
if(d1==d2)rot(p,d1^1);
else rot(p->ch[d1],d2^1);
}
rot(p,d1^1);
}
}
void del(pnode p){
if(p==null)return;
del(p->ch[0]);del(p->ch[1]);
delete p;
}
pnode build(int l,int r,int*arr){
if(l>r)return null;
int mid=(l+r)/2;
pnode p=newnode(*(arr+mid));
p->ch[0]=build(l,mid-1,arr);p->ch[1]=build(mid+1,r,arr);
p->maintain();
return p;
}
void print_splay(pnode p){
if(p==null)return;
p->pushdown();
print_splay(p->ch[0]);
// printf("%d ",p->key);
print_splay(p->ch[1]);
}
int _read(){
char ch=getchar();int p,sum=0;
while((ch!='-')&&(!(ch>='0'&&ch<='9')))ch=getchar();
if(ch=='-')p=-1,ch=getchar();else p=1;
while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=getchar();
return sum*p;
}
int n,m,c[maxn];
int main(){
null->ch[0]=null->ch[1]=null;null->num=null->flip=null->sum=null->key=0;null->tag=null->maxl=null->maxr=null->maxs=INF;
freopen("sequence.in","r",stdin);
freopen("sequence.out","w",stdout);
n=_read();m=_read();
for(int i=1;i<=n;i++)c[i]=_read();c[n+1]=0;
rt=build(0,n+1,c);
for(int t=1;t<=m;t++){
char s[30];scanf("%s",s);
if(s[0]=='I'){
int x=_read(),tot=_read();x++;
for(int i=1;i<=tot;i++)c[i]=_read();
splay(rt,x);splay(rt->ch[1],1);
rt->ch[1]->ch[0]=build(1,tot,c);
rt->ch[1]->maintain();rt->maintain();
}else
if(s[0]=='D'){
int x=_read(),tot=_read();x++;
splay(rt,x-1);splay(rt->ch[1],tot+1);
del(rt->ch[1]->ch[0]);rt->ch[1]->ch[0]=null;
rt->ch[1]->maintain();rt->maintain();
}else
if(s[2]=='K'){
int x=_read(),tot=_read(),y=_read();x++;
splay(rt,x-1);splay(rt->ch[1],tot+1);
rt->ch[1]->ch[0]->add_tag(y);
}else
if(s[0]=='R'){
int x=_read(),tot=_read();x++;
splay(rt,x-1);splay(rt->ch[1],tot+1);
rt->ch[1]->ch[0]->add_flip();
}else
if(s[0]=='G'){
int x=_read(),tot=_read();x++;
splay(rt,x-1);splay(rt->ch[1],tot+1);
printf("%d\n",rt->ch[1]->ch[0]->sum);
}else{
splay(rt,1);splay(rt->ch[1],rt->ch[1]->num);
printf("%d\n",rt->ch[1]->ch[0]->maxs);
}
// print_splay(rt);
}
return 0;
}