#2445 普通平衡树
输入
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
插入x数
删除x数(若有多个相同的数,因只删除一个)
查询x数的排名(若有多个相同的数,因输出最小的排名)
查询排名为x的数
求x的前驱(前驱定义为小于x,且最大的数)
求x的后继(后继定义为大于x,且最小的数)
输出
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
样例输入
8
1 10
1 20
1 30
3 20
4 2
2 10
5 25
6 -1
样例输出
2
20
20
20
提示
n<=100000 所有数字均在-10^7 到10^7内
SOL
Splay模板题。
首先,splay是基于BST得到的一种伸展树。
BST,即Binary Search Tree(二叉查找树),满足左儿子权值<当前节点权值<右儿子权值。
二叉查找树能够支持多种动态集合操作,只要所维护的数据集合存在偏序关系(简单来
说就是定义了小于等于)。因此,在信息学竞赛中,二叉查找树起着非常重要的作用,它可
以被用来表示有序集合、建立索引或优先队列等。我们也常常利用它动态地维护一个有序数
集,利用二叉查找树找到新加入序列中的数插入的位置。
BST本身可以支持删除、修改、查询等诸多操作,但是复杂度完全由树的深度决定,因此容易被卡。而Splay是基于BST得到的一种平衡树,它可以通过旋转操作来使均摊复杂度达到O(logN)O(logN)O(logN)。
旋转分成左旋和右旋,代码里是将这两种操作合并到了一起的。
Splay伸展操作
1.节点 x 的父节点 y 是根节点。这时,如果 x 是 y 的左孩子,我们进行一次 Zig
(右旋)操作;如果 x 是 y 的右孩子,则我们进行一次 Zag(左旋)操作。经过旋转,x 成
为二叉查找树 S 的根节点,调整结束。
2.节点x 的父节点y 不是根节点,y 的父节点为z,且x 与y 同时是各自父节点
的左孩子或者同时是各自父节点的右孩子。这时,我们进行一次Zig-Zig操作或者Zag-Zag操作。
3.节点x的父节点y不是根节点,y的父节点为z,x与y中一个是其父节点的左孩子
而另一个是其父节点的右孩子。这时,我们进行一次Zig-Zag操作或者Zag-Zig 操作。
关于修改、查询的操作
(1) find(x,S):判断元素x是否在伸展树S表示的有序集中。首先,访问根节点,如果x比
根节点权值小则访问左儿子;如果x比根节点权值大则访问右儿子;如果权值相等,则说明x
在树中;如果访问到空节点,则x不在树中。如果x在树中,则再执行Splay(x,S)调整伸展树。
(2) insert(x,S):将元素x插入伸展树S表示的有序集中。首先,访问根节点,如果x比根节
点权值小则访问左儿子;如果x比根节点权值大则访问右儿子;如果访问到空节点t,则把x插
入该节点,然后执行Splay(t,S)。
(3)merge(S1,S2):将两个伸展树S1与S2合并成为一个伸展树。其中S1的所有元素都小于S2
的所有元素。首先,我们找到伸展树S1中最大的一个元素x,再通过Splay(x,S1)将x调整到伸
展树S1的根。然后再将S2作为x节点的右子树。这样,就得到了新的伸展树S。
(4) delete(x,S):把节点x从伸展树表示的有序集中删除。首先,执行Splay(x,S)将x旋转至
根节点。然后取出x的左子树S1和右子树S2,执行merge(S1,S2)把两棵子树合并成S。
Besides,我们还可以用Splay来维护序列,而有时我们需要对序列上某个区间进行操作,这种情况
下,Splay还能支持提取区间操作。比如我们要对[l,r]进行操作,则我们在Splay
中把l -1对应的节点Splay到根,把 r +1Splay到根的右儿子处,之后 r +1的左儿子及它左儿
子的子树,实际上就是区间[l,r]所对应的平衡树。
在Splay上我们也可以类似线段树,打上lazy-tag标记,表示对这个Splay子树中的所有节
点同时进行某些操作。利用标记下放操作,我们只要保证访问某个节点时,根到它路径上没
有标记存在即可保证正确性。
代码:
结构体数组版:
#include<bits/stdc++.h>
#define f(p) t[p].f
#define lc(p) t[p].lc
#define rc(p) t[p].rc
#define v(p) t[p].v
#define siz(p) t[p].siz
#define N 100005
using namespace std;
inline int rd(){
static char ch=0;int register data=0,w=1;
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(isdigit(ch))data=(data<<1)+(data<<3)+ch-'0',ch=getchar();
return data*w;
}
int rt,tot,m[N];
struct Splay_Tree{int f,lc,rc,v,siz;}t[N];
inline int get(int p){return rc(f(p))==p;}
inline void update(int p){siz(p)=siz(lc(p))+siz(rc(p))+m[p];}
inline void rotate(int x){//旋转
int register y=f(x),z=f(y),d=get(x),d1=get(y);
if(d)rc(y)=lc(x),f(rc(y))=y,lc(x)=y;
else lc(y)=rc(x),f(lc(y))=y,rc(x)=y;
f(y)=x,f(x)=z;
if(z){if(d1)rc(z)=x;else lc(z)=x;}
update(x),update(y);
}
inline void splay(int p,int S){//伸展
while(f(p)!=S){
if(f(f(p))!=S){
if(get(p)==get(f(p)))rotate(f(p));
else rotate(p);
}
else rotate(p);
}
update(p);
if(!S)rt=p;
}
inline void insert(int val){//插入
int register u=rt,v=0,dir;
while(u){
++siz(u);v=u;
if(val==v(u)){m[u]++;return;}
if(val<v(u))dir=0,u=lc(u);
else dir=1,u=rc(u);
}
if(!rt)rt=u;//
u=++tot;
f(u)=v,v(u)=val,siz(u)=m[u]=1;
if(v)(dir==0?lc(v):rc(v))=u;
splay(u,0);
}
inline int findmin(int u){if(u)while(lc(u))u=lc(u);return u;}//找到最小值的编号
inline int findmax(int u){if(u)while(rc(u))u=rc(u);return u;}//找到最大值的编号
inline void merge(int u,int v){
int register w=findmax(u);
splay(w,0);
f(v)=w;rc(w)=v;
update(w);
}
inline int find(int x){//找到权值为x的节点的编号
int register u=rt;
while(u){
if(v(u)==x)return u;
if(v(u)<x)u=rc(u);
else u=lc(u);
}
return 0;
}
inline void erase(int u){//删除
u=find(u);
if(!u)return;
splay(u,0);
if(m[u]>1){siz(u)--;m[u]--;return;}
if(siz(u)==1)rt=0;
else if(!lc(u)||!rc(u)){rt=lc(u)+rc(u);f(rt)=0;}
else{f(lc(u))=f(rc(u))=0;merge(lc(u),rc(u));}
}
inline int pre(int x){//x的前驱
int register u=rt,ans=0;
while(u){
if(v(u)<x)ans=v(u);
if(v(u)>=x)u=lc(u);
else u=rc(u);
}
return ans;
}
inline int nxt(int x){//x的后继
int register u=rt,ans=0;
while(u){
if(v(u)>x)ans=v(u);
if(v(u)<=x)u=rc(u);
else u=lc(u);
}
return ans;
}
inline int rank(int v){//权值v的排名
int register u=rt,ans=0;
while(u){
if(v(u)>v)u=lc(u);
else{
if(v==v(u)){return ans+siz(lc(u))+1;}
if(v<v(u))u=lc(u);
else ans+=siz(lc(u))+m[u],u=rc(u);
}
}
return ans;
}
inline int kth(int k){//求第k大
int register u=rt;
while(u){
if(siz(lc(u))<k&&siz(lc(u))+m[u]>=k)return v(u);
if(siz(lc(u))>=k)u=lc(u);
else k-=siz(lc(u))+m[u],u=rc(u);
}
return 0;
}
signed main(){
int register ins=rd();
while(ins--){
int register op=rd(),x=rd();
switch(op){
case 1:{insert(x);break;}
case 2:{erase(x);break;}
case 3:{printf("%d\n",rank(x));break;}
case 4:{printf("%d\n",kth(x));break;}
case 5:{printf("%d\n",pre(x));break;}
case 6:{printf("%d\n",nxt(x));break;}
}
}
return 0;
}
静态数组版:
#include<bits/stdc++.h>
#define N 100005
using namespace std;
inline int rd(){
static char ch=0;int register data=0,w=1;
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(isdigit(ch))data=(data<<1)+(data<<3)+ch-'0',ch=getchar();
return data*w;
}
int rt,tot,m[N];
int f[N],v[N],siz[N],ch[N][2];
inline int get(int p){return ch[f[p]][1]==p;}
inline void update(int p){siz[p]=siz[ch[p][0]]+siz[ch[p][1]]+m[p];}
inline void rotate(int x){
int register y=f[x],z=f[y],d=get(x),d1=get(y);
ch[y][d]=ch[x][d^1];f[ch[y][d]]=y;
ch[x][d^1]=y;f[y]=x;f[x]=z;
if(z)ch[z][d1]=x;
update(x),update(y);
}
inline void splay(int p,int S){
while(f[p]!=S){
if(f[f[p]]!=S){
if(get(p)==get(f[p]))rotate(f[p]);
else rotate(p);
}
else rotate(p);
}
update(p);
if(!S)rt=p;
}
inline void insert(int val){
int register u=rt,vv=0,dir;
while(u){
++siz[u];vv=u;
if(val==v[u]){m[u]++;return;}
if(val<v[u])dir=0,u=ch[u][0];
else dir=1,u=ch[u][1];
}
if(!rt)rt=u;//
u=++tot;
f[u]=vv,v[u]=val,siz[u]=m[u]=1;
if(vv)(dir==0?ch[vv][0]:ch[vv][1])=u;
splay(u,0);
}
inline int findmin(int u){if(u)while(ch[u][0])u=ch[u][0];return u;}//找到子数最小值的编号
inline int findmax(int u){if(u)while(ch[u][1])u=ch[u][1];return u;}
inline void merge(int u,int v){
int register w=findmax(u);
splay(w,0);
f[v]=w;ch[w][1]=v;
update(w);
}
inline int find(int x){//找到权值为x的节点的编号
int register u=rt;
while(u){
if(v[u]==x)return u;
if(v[u]<x)u=ch[u][1];
else u=ch[u][0];
}
return 0;
}
inline void erase(int u){
u=find(u);
if(!u)return;
splay(u,0);
if(m[u]>1){siz[u]--;m[u]--;return;}
if(siz[u]==1)rt=0;
else if(!ch[u][0]||!ch[u][1]){rt=ch[u][0]+ch[u][1];f[rt]=0;}
else{f[ch[u][0]]=f[ch[u][1]]=0;merge(ch[u][0],ch[u][1]);}
}
inline int pre(int x){
int register u=rt,ans=0;
while(u){
if(v[u]<x)ans=v[u];
if(v[u]>=x)u=ch[u][0];
else u=ch[u][1];
}
return ans;
}
inline int nxt(int x){
int register u=rt,ans=0;
while(u){
if(v[u]>x)ans=v[u];
if(v[u]<=x)u=ch[u][1];
else u=ch[u][0];
}
return ans;
}
inline int rank(int x){
int register u=rt,ans=0;
while(u){
if(v[u]>x)u=ch[u][0];
else{
if(x==v[u]){return ans+siz[ch[u][0]]+1;}
if(x<v[u])u=ch[u][0];
else ans+=siz[ch[u][0]]+m[u],u=ch[u][1];
}
}
return ans;
}
inline int kth(int k){
int register u=rt;
while(u){
if(siz[ch[u][0]]<k&&siz[ch[u][0]]+m[u]>=k)return v[u];
if(siz[ch[u][0]]>=k)u=ch[u][0];
else k-=siz[ch[u][0]]+m[u],u=ch[u][1];
}
return 0;
}
signed main(){
int register ins=rd();
while(ins--){
int register op=rd(),x=rd();
switch(op){
case 1:{insert(x);break;}
case 2:{erase(x);break;}
case 3:{printf("%d\n",rank(x));break;}
case 4:{printf("%d\n",kth(x));break;}
case 5:{printf("%d\n",pre(x));break;}
case 6:{printf("%d\n",nxt(x));break;}
}
}
return 0;
}