详解
#include<iostream>
#include<cstdio>
#define MAXN 100010
using namespace std;
int root,tot;
struct Splay
{
int fa;
int ch[2];
int val;
int cnt;
int size;
}
t[MAXN];
void maintain(int x)
{
t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;
}
bool get(int x)
{
return x==t[t[x].fa].ch[1];
}
void clear(int x)
{
t[x].ch[0]=t[x].ch[1]=t[x].fa=t[x].val=t[x].cnt=t[x].size=0;
}
void rotate(int x)
{
int y=t[x].fa,z=t[y].fa,chk=get(x);
t[y].ch[chk]=t[x].ch[chk^1];
if(t[x].ch[chk^1])
t[t[x].ch[chk^1]].fa=y;
t[x].ch[chk^1]=y;
t[y].fa=x;
t[x].fa=z;
if(z)
t[z].ch[y==t[z].ch[1]]=x;
maintain(y);
maintain(x);
}
void splay(int x)
{
for(int f=t[x].fa;f=t[x].fa,f;rotate(x))
if(t[f].fa)
rotate(get(x)==get(f)?f:x);
root=x;
}
void insert(int k)
{
if(!root)
{
t[++tot].val=k;
t[tot].cnt++;
root=tot;
maintain(root);
return;
}
int cur=root,f=0;
while(1)
{
if(t[cur].val==k)
{
t[cur].cnt++;
maintain(cur);
maintain(f);
splay(cur);
break;
}
f=cur;
cur=t[f].ch[t[f].val<k];
if(!cur)
{
t[++tot].val=k;
t[tot].cnt++;
t[tot].fa=f;
t[f].ch[t[f].val<k]=tot;
maintain(tot);
maintain(f);
splay(tot);
break;
}
}
}
int rnk(int k)
{
int res=0,cur=root;
while(1)
{
if(k<t[cur].val)
cur=t[cur].ch[0];
else
{
res+=t[t[cur].ch[0]].size;
if(k==t[cur].val)
{
splay(cur);
return res+1;
}
res+=t[cur].cnt;
cur=t[cur].ch[1];
}
}
}
int kth(int k)
{
int cur=root;
while(1)
{
if(t[cur].ch[0]&&k<=t[t[cur].ch[0]].size)
cur=t[cur].ch[0];
else
{
k-=t[t[cur].ch[0]].size+t[cur].cnt;
if(k<=0)
{
splay(cur);
return t[cur].val;
}
cur=t[cur].ch[1];
}
}
}
int pre()
{
int cur=t[root].ch[0];
if(!cur)
return cur;
while(t[cur].ch[1])
cur=t[cur].ch[1];
splay(cur);
return cur;
}
int nxt()
{
int cur=t[root].ch[1];
if(!cur)
return cur;
while(t[cur].ch[0])
cur=t[cur].ch[0];
splay(cur);
return cur;
}
void del(int k)
{
rnk(k);
if(t[root].cnt>1)
{
t[root].cnt--;
maintain(root);
return;
}
if(!t[root].ch[0]&&!t[root].ch[1])
{
clear(root);
root=0;
return;
}
if(!t[root].ch[0])
{
int cur=root;
root=t[root].ch[1];
t[root].fa=0;
clear(cur);
return;
}
if(!t[root].ch[1])
{
int cur=root;
root=t[root].ch[0];
t[root].fa=0;
clear(cur);
return;
}
int cur=root;
int x=pre();
t[t[cur].ch[1]].fa=root;
t[root].ch[1]=t[cur].ch[1];
clear(cur);
maintain(root);
}
int n,op,x;
int main()
{
scanf("%d",&n);
while(n--)
{
scanf("%d%d",&op,&x);
if(op==1)
insert(x);
else if(op==2)
del(x);
else if(op==3)
printf("%d\n",rnk(x));
else if(op==4)
printf("%d\n",kth(x));
else if(op==5)
{
insert(x);
printf("%d\n",t[pre()].val);
del(x);
}
else
{
insert(x);
printf("%d\n",t[nxt()].val);
del(x);
}
}
return 0;
}
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
const int mod=5302121,pri=832211;
inline int read()
{
char c=getchar();int x=0,f=1;
while(c<'0'||c>'9'){if(c=='-')f=-1; c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-'0'; c=getchar();}
return x*f;
}
inline int rand()
{
static int set=15;
return set=(long long)(set*pri)%mod;
}
int tree[maxn][2], size[maxn], wv[maxn], num[maxn], c[maxn];
int cnt, tot, pre, nxt;
#define lc tree[node][0]
#define rc tree[node][1]
inline void update(int node)
{
if(node) size[node]=size[lc]+size[rc]+c[node];
}
inline void rotate(int &node,int son)
{
int tmp=tree[node][son];
tree[node][son]=tree[tmp][son^1]; tree[tmp][son^1]=node;
size[tmp]=size[node];update(node);
node=tmp;
}
inline void insert(int val,int &node)
{
if(!node)
{
node=++cnt; size[node]=1; wv[node]=rand();
num[node]=val; c[node]=1;
return;
}
else if(val == num[node]) c[node]++,size[node]++;
else
{
size[node]++;
if(val < num[node])
{
insert(val,lc);
if(wv[lc] < wv[node]) rotate(node,0);
}
else
{
insert(val,rc);
if(wv[rc] < wv[node]) rotate(node,1);
}
}
}
inline void del(int val,int &node)
{
if(!node) return;
if(num[node] == val)
{
if(c[node] > 1)
{
c[node]--;
size[node]--;
return;
}
if(lc*rc == 0) node=lc+rc;
else if(wv[lc] > wv[rc]) rotate(node,1),del(val,node);
else rotate(node,0),del(val,node);
}
else
{
size[node]--;
if(num[node] > val) del(val,lc);
else del(val,rc);
}
update(node);
}
inline bool search(int val,int &node)
{
if(!node) return false;
if(num[node] == val) return true;
if(num[node] < val) return search(val,rc);
else return search(val,lc);
}
inline int find_max(int &node)
{
if(!rc) return num[node];
else return find_max(rc);
}
inline int find_min(int &node)
{
if(!lc) return num[node];
else return find_min(lc);
}
inline void find_pre(int val,int node)
{
if(!node) return;
if(val > num[node]) pre=node,find_pre(val,rc);
else find_pre(val,lc);
}
inline void find_nxt(int val,int node)
{
if(!node) return;
if(val < num[node]) nxt=node,find_nxt(val,lc);
else find_nxt(val,rc);
}
inline int find_rank(int val,int node)
{
if(!node) return 1;
if(num[node] == val) return size[lc]+1;
else if(val < num[node]) return find_rank(val,lc);
else return size[lc]+c[node]+find_rank(val,rc);
}
inline int find_kth(int k,int node)
{
if(node == 0) return 0;
if(k <= size[lc]) return find_kth(k,lc);
else if(k > size[lc]+c[node]) return find_kth(k-size[lc]-c[node],rc);
else return num[node];
}
inline int dep(int node)
{
if(!node) return 0;
int l=dep(lc),r=dep(rc);
return (l < r) ? (r+1) : (l+1);
}
inline int middle_trave(int node)
{
if(!node)
{
middle_trave(lc);
printf("%d\t", node);
middle_trave(rc);
}
}
int main()
{
freopen("t.in","r",stdin);
int n=read(),op,x;
while(n--){
if(n == 1)
n=1;
op=read();x=read();
switch(op){
case 1:insert(x,tot);break;
case 2:del(x,tot);break;
case 3:printf("%d\n",find_rank(x,tot));break;
case 4:printf("%d\n",find_kth(x,tot));break;
case 5:pre=0;find_pre(x,tot);printf("%d\n",num[pre]);break;
case 6:nxt=0;find_nxt(x,tot);printf("%d\n",num[nxt]);break;
}
}
return 0;
}