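// Task: implement a balanced binary search tree (here, a splay tree) supporting:
// insert a number, delete a number, query a number's rank (reported as how many
// stored numbers are smaller than it), and query the k-th smallest stored number.
// Original statement: 写一棵平衡树,要求实现以下操作,插入、删除数字,查询某个数字是第几大的,查询第k大的数字是几。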
#include <cstdio>
// Largest int value: all bits set, shifted right once to clear the sign bit.
// Also used as the key of the upper sentinel node and as kth()'s "no answer" value.
const int MAXINT = (int)(~0u >> 1);
// One node of the splay tree. Nodes live in a preallocated pool and are
// (re)initialised with clear().
struct SplayNode {
    SplayNode *ls, *rs, *f;  // left child, right child, parent (NULL if absent)
    int x, size;             // stored key; number of nodes in this subtree

    // Reset this node to a leaf holding key `xx` whose parent is `ff`.
    // Returns `this` so a freshly taken pool slot can be initialised inline.
    SplayNode *clear(int xx, SplayNode *ff = NULL) {
        ls = NULL;
        rs = NULL;
        f = ff;
        x = xx;
        size = 1;
        return this;
    }

    // Recompute `size` from the children's cached sizes.
    void repairSize() {
        size = 1 + (ls ? ls->size : 0) + (rs ? rs->size : 0);
    }

    // Single rotation: lift this node above its parent, keeping in-order
    // sequence intact and relinking the grandparent (if any) to this node.
    // Requires a non-NULL parent.
    void rot() {
        SplayNode *p = f;
        SplayNode *g = p->f;
        if (p->ls == this) {
            // Our right subtree becomes the parent's left subtree.
            p->ls = rs;
            if (rs) rs->f = p;
            rs = p;
        } else {
            // Our left subtree becomes the parent's right subtree.
            p->rs = ls;
            if (ls) ls->f = p;
            ls = p;
        }
        p->f = this;
        f = g;
        if (g) {
            if (g->ls == p) {
                g->ls = this;
            } else {
                g->rs = this;
            }
        }
        // The old parent is now our child, so fix its size first.
        p->repairSize();
        repairSize();
    }

    // Which side of the parent this node hangs on: -1 left, 1 right, 0 if root.
    int dir() {
        if (!f) return 0;
        return this == f->ls ? -1 : 1;
    }
};
// Node pool: clear() takes two slots for the sentinels and every successful
// insert() takes one more; erased nodes are simply detached, never reused.
SplayNode b[200010];
// Next unused slot in the pool.
SplayNode *bp;
// Current root of the splay tree.
SplayNode *root;
void print(int root) {
printf("Root: %d\n",root);
for (SplayNode *cur=b;cur!=bp;cur++) {
printf("Node %d:\n",(int)(cur-b));
printf(" ls:%d rs:%d f:%d\n",(int)(cur->ls-b),(int)(cur->rs-b),(int)(cur->f-b));
printf(" x:%d size:%d\n",cur->x,cur->size);
int x,num,size;
}
}
// Rotate x upward until its parent is `f` (f == NULL: until x is the tree
// root), using zig, zig-zig and zig-zag steps. Returns x.
SplayNode *splay(SplayNode *x, SplayNode *f = NULL) {
    for (SplayNode *p = x->f; p != f; p = x->f) {
        if (p->f != f) {
            // At least two levels below f: zig-zig (x and its parent on the
            // same side) rotates the parent first; zig-zag rotates x first.
            if (x->dir() == p->dir()) {
                p->rot();
            } else {
                x->rot();
            }
        }
        x->rot();
    }
    return x;
}
// Find the node with the largest key strictly less than x inside the subtree
// `root`, and splay it into that subtree's position (its parent becomes root->f).
// Returns the found node. If no key < x exists, the NULL result is passed to
// splay() and dereferenced, so callers must guarantee a smaller key exists.
SplayNode *less(SplayNode *root, int x) {
    SplayNode *best = NULL;
    SplayNode *cur = root;
    while (cur != NULL) {
        if (cur->x < x) {
            best = cur;       // candidate; look right for a larger one
            cur = cur->rs;
        } else {
            cur = cur->ls;
        }
    }
    return splay(best, root->f);
}
// Mirror of less(): find the node with the smallest key strictly greater than
// x inside the subtree `root`, and splay it into that subtree's position.
// Returns the found node. If no key > x exists, the NULL result is passed to
// splay() and dereferenced, so callers must guarantee a larger key exists.
SplayNode *greater(SplayNode *root, int x) {
    SplayNode *best = NULL;
    SplayNode *cur = root;
    while (cur != NULL) {
        if (cur->x > x) {
            best = cur;       // candidate; look left for a smaller one
            cur = cur->ls;
        } else {
            cur = cur->rs;
        }
    }
    return splay(best, root->f);
}
void insert(int x) {
root=less(root,x);
root->rs=greater(root->rs,x);
if (root->rs->ls!=NULL) return;
root->rs->ls=(bp++)->clear(x,root->rs);
root->rs->size++;
root->size++;
}
void erase(int x) {
root=less(root,x);
root->rs=greater(root->rs,x);
if (root->rs->ls==NULL) return;
root->rs->ls=NULL;
root->rs->size--;
root->size--;
}
// Return the k-th smallest key (1-based) among ALL nodes in the tree,
// sentinels included (k == 1 is the -MAXINT sentinel), and splay that node
// to the root. Returns MAXINT when k is outside [1, root->size].
int kth(int k) {
    // Without the lower bound, k < 1 never reaches k == 1 during the descent
    // and control would fall off the end of this non-void function.
    if (k < 1 || k > root->size) return MAXINT;
    SplayNode *cur = root;
    while (cur) {
        int leftSize = cur->ls ? cur->ls->size : 0;
        if (k <= leftSize) {            // target lies in the left subtree
            cur = cur->ls;
            continue;
        }
        k -= leftSize;
        if (k == 1) {                   // target is cur itself
            root = splay(cur, root->f);
            return cur->x;
        }
        k--;                            // skip cur, continue in the right subtree
        cur = cur->rs;
    }
    return MAXINT;  // unreachable while cached subtree sizes are consistent
}
int count(int x) {
root=less(root,x);
root->rs=greater(root->rs,x);
if (root->ls) return root->ls->size+1;
else return 1;
}
void clear() {
bp=b;
root=(bp++)->clear(-MAXINT);
root->rs=(bp++)->clear(MAXINT,root);
root->size++;
}
int main() {
int q,x;
char c;
scanf("%d",&q);
clear();
while (q--) {
scanf(" %c%d",&c,&x);
if (c=='I') insert(x);
else if (c=='D') erase(x);
else if (c=='K') {
int ans=kth(x+1);
if (ans==MAXINT) printf("invalid\n");
else printf("%d\n",ans);
} else {
printf("%d\n",count(x)-1);
}
//print(root-b);
}
return 0;
}