题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=3224
题意:模拟平衡树
xjb练习splay,注意对重复值的处理,用cnt数组计数。
AC代码:(依旧长的一匹)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
const int maxn = 1000050;
const int inf = 0x3f3f3f3f;
typedef long long ll;
int l,r,sz,rt;
int f[maxn];
int ch[maxn][2];
int val[maxn];
int Size[maxn];
int cnt[maxn];
inline int get(int x){
return ch[f[x]][1]==x;
}
inline void clear(int x){
ch[x][0]=ch[x][1]=f[x]=cnt[x]=val[x]=Size[x]=0;
}
int get_pre(){
int x=ch[rt][0];
if(x==0) return inf;
while(ch[x][1]) x=ch[x][1];
return x;
}
int get_next(){
int x=ch[rt][1];
if(x==0) return inf;
while(ch[x][0]) x=ch[x][0];
return x;
}
void init(){
for(int i=0;i<=sz;i++){
ch[i][0]=ch[i][1]=Size[i]=cnt[i]=0;
}
rt=sz=0;
}
void update(int x){
if(x){
Size[x]=cnt[x];
if(ch[x][0]) Size[x]+=Size[ch[x][0]];
if(ch[x][1]) Size[x]+=Size[ch[x][1]];
}
}
void rotate(int x){
int y=f[x];
int z=f[y];
int k = get(x);
ch[y][k]=ch[x][k^1]; f[ch[y][k]]=y;
ch[x][k^1]=y; f[y]=x;
f[x]=z;
if(z) ch[z][ch[z][1]==y]=x;
update(y);
update(x);
}
void splay(int x){
int y;
int z;
while(y=f[x]){
z=f[y];
if(z){
if(get(x)==get(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
rt=x;
}
int find(int x){
int now=rt;
int ans=0;
while(true){
if(x<val[now]) now=ch[now][0];
else{
ans+=(ch[now][0]?Size[ch[now][0]]:0);
if(x==val[now]){
splay(now);
return ans+1;
}
ans+=cnt[now];
now=ch[now][1];
}
}
}
int find_k(int x){
int now=rt;
while(true){
if(ch[now][0]&&x<=Size[ch[now][0]]){
now=ch[now][0];
}else{
int tmp=(ch[now][0]?Size[ch[now][0]]:0)+cnt[now];
if(x<=tmp) return val[now];
x-=tmp;
now=ch[now][1];
}
}
}
void del(int x){
find(x);
if(cnt[rt]>1){
cnt[rt]--;
update(rt);
return;
}
if(!ch[rt][0]&&!ch[rt][1]){ clear(rt); rt=0 ;return ;}
if(!ch[rt][0]){
int p=rt;
rt=ch[rt][1];
f[rt]=0;
clear(p);
return ;
}else if(!ch[rt][1]){
int p=rt;
rt=ch[rt][0];
f[rt]=0;
clear(p);
return;
}
int lb=get_pre();
int p=rt;
splay(lb);
f[ch[p][1]]=rt;
ch[rt][1]=ch[p][1];
clear(p);
update(rt);
}
void insert(int x){
if(rt==0){
sz++;
ch[sz][0]=ch[sz][1]=f[sz]=0;
val[sz]=x;Size[sz]=cnt[sz]=1;rt=sz;
return;
}
int now=rt;
int fa=0;
while(true){
if(x==val[now]){
cnt[now]++;
update(now);
update(fa);
splay(now);
break;
}
fa=now;
now=ch[now][val[now]<x];
if(now==0){
sz++;
ch[sz][0]=ch[sz][1]=0;
val[sz]=x;Size[sz]=cnt[sz]=1;
f[sz]=fa;ch[fa][val[fa]<x]=sz;
update(fa);
splay(sz);
break;
}
}
}
int main(){
init();
int q;
scanf("%d",&q);
while(q--){
int k,x;
scanf("%d%d",&k,&x);
if(k==1) insert(x);
else if(k==2) del(x);
else if(k==3) printf("%d\n",find(x));
else if(k==4) printf("%d\n",find_k(x));
else if(k==5){
insert(x);
printf("%d\n",val[get_pre()]);
del(x);
}else if(k==6){
insert(x);
printf("%d\n",val[get_next()]);
del(x);
}
}
return 0;
}