3196: Tyvj 1730 二逼平衡树
Time Limit: 10 Sec Memory Limit: 128 MBSubmit: 2276 Solved: 937
[ Submit][ Status][ Discuss]
Description
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
Input
第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继
Output
对于操作1,2,4,5各输出一行,表示查询结果
Sample Input
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
Sample Output
4
3
4
9
HINT
1.n和m的数据范围:n,m<=50000
2.序列中每个数的数据范围:[0,1e8]
Source
题解:线段树套splay
先建立一棵区间线段树,然后线段树中的每个节点建立一棵位置在当前点表示的区间的权值splay(小的在左儿子,大的在右儿子)
solve1: 直接把所有在范围内的区间内比他小的数统计一下,然后+1
solve2: 二分答案,通过solve1,计算mid在区间中的排名
solve3: 把这个位置原本的数从所有包含这个位置的区间中删去,然后加入新的值
solve4:从所有在范围内的区间(给出的范围可能在线段树中跨越多个区间)中找前驱最大的
solve5:从所有在范围内的区间(给出的范围可能在线段树中跨越多个区间)中找后继最小的
思路非常清晰明了,但是实现起来异常的麻烦,各种手残简直鬼畜。
刚开始姿势不够优越,TLE。改了姿势后刚好过了,"9724 ms"。。。。。
<span style="font-size:18px;">#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 2000003
#define M 50003
using namespace std;
int n,m;
int ls[4*M],rs[4*M],root[4*M],pd;
int ch[N][3],fa[N],maxn,a[N],size[N],key[N],cnt[N],sz;
void clear(int x)
{
size[x]=key[x]=cnt[x]=ch[x][1]=ch[x][0]=fa[x]=0;
}
int get(int x)
{
return ch[fa[x]][1]==x;
}
void update(int x)
{
size[x]=cnt[x];
if (ch[x][0]) size[x]+=size[ch[x][0]];
if (ch[x][1]) size[x]+=size[ch[x][1]];
}
void rotate(int x)
{
int y=fa[x]; int z=fa[y]; int which=get(x);
if (z)
ch[z][ch[z][1]==y]=x;
fa[x]=z; ch[y][which]=ch[x][which^1]; fa[ch[y][which]]=y;
ch[x][which^1]=y; fa[y]=x;
update(y); update(x);
}
void splay(int x,int &root)
{
for (int f;(f=fa[x]);rotate(x))
if (fa[f])
rotate(get(x)==get(f)?f:x);
root=x;
}
void insert(int &root,int x)
{
if (!root)
{
root=++sz; clear(sz);
size[sz]=cnt[sz]=1; key[sz]=x;
return;
}
int f=0; int now=root;
while(true)
{
if (x==key[now])
{
cnt[now]++; update(now); splay(now,root); return;
}
f=now;
now=ch[now][key[now]<x];
if (!now)
{
sz++; clear(sz);
key[sz]=x; cnt[sz]=size[sz]=1; fa[sz]=f; ch[f][key[f]<x]=sz;
update(f); splay(sz,root); return;
}
}
}
int find(int x,int &root) //查找x的位置
{
int now=root;
while (true)
{
if (now==0) return 0;
if (key[now]==x)
{
splay(now,root);
return now;
}
if (x<key[now])
now=ch[now][0];
if (x>key[now])
now=ch[now][1];
}
}
int findx(int x,int &root) //查找当前区间内比x小的数有多少个
{
int now=root; int ans=0;
while (true)
{
if (!now) return ans;
if (x<key[now])
now=ch[now][0];
else
{
ans+=(ch[now][0]?size[ch[now][0]]:0);
if (x==key[now])
{
splay(now,root); pd=true; return ans;
}
ans+=cnt[now];
now=ch[now][1];
}
}
}
int pre(int root)
{
int now=ch[root][0];
while (ch[now][1]) now=ch[now][1];
return now;
}
int next(int root)
{
int now=ch[root][1];
while (ch[now][0]) now=ch[now][0];
return now;
}
void del(int &root,int x)
{
splay(x,root);
if (cnt[root]>1)
{
cnt[root]--; update(root); return;
}
if (!ch[root][1]&&!ch[root][0])
{
clear(root); root=0; return ;
}
if (!ch[root][1])
{
int old=root; root=ch[root][0]; fa[root]=0; clear(old); return;
}
if (!ch[root][0])
{
int old=root; root=ch[root][1]; fa[root]=0; clear(old); return;
}
int k=pre(root); int old=root; splay(k,root);
ch[k][1]=ch[old][1]; fa[ch[k][1]]=k; clear(old);
update(k); return;
}
void pointchange(int now,int l,int r,int x,int v)
{
insert(root[now],v);
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid)
pointchange(now<<1,l,mid,x,v);
else
pointchange(now<<1|1,mid+1,r,x,v);
}
int solve1(int now,int l,int r,int ll,int rr,int k)
{
if (l>=ll&&r<=rr)
{
return findx(k,root[now]);
}
int mid=(l+r)/2;
int ans=0;
if (ll<=mid)
ans+=solve1(now<<1,l,mid,ll,rr,k);
if (rr>mid)
ans+=solve1(now<<1|1,mid+1,r,ll,rr,k);
return ans;
}
void solve2(int l,int r,int k)
{
int head=0,tail=maxn+1,mid;
while (head<tail)
{
mid=(head+tail)/2;
if (solve1(1,1,n,l,r,mid)<k)
head=mid+1;
else tail=mid;
}
printf("%d\n",head-1);
}
void solve3(int now,int l,int r,int x,int v)
{
int t=find(a[x],root[now]);
del(root[now],t);
insert(root[now],v);
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid)
solve3(now<<1,l,mid,x,v);
else
solve3(now<<1|1,mid+1,r,x,v);
}
int find_next_min(int rt,int x)
{
int now=root[rt],t=0,ans=-1;
while (now)
{
if (key[now]<x)
{
if (ans<key[now])ans=key[now];
now=ch[now][1];
}
else now=ch[now][0];
}
return ans;
}
int find_next_max(int rt,int x)
{
int now=root[rt],t=0,ans=1000000000;
while (now)
{
if (key[now]>x)
{
if (ans>key[now]) ans=key[now];
now=ch[now][0];
}
else now=ch[now][1];
}
return ans;
}
int solve4(int now,int l,int r,int ll,int rr,int x)
{
if (l>=ll&&r<=rr)
{
return find_next_min(now,x);
}
int mid=(l+r)/2;
int maxn=0;
if (ll<=mid)
maxn=max(maxn,solve4(now<<1,l,mid,ll,rr,x));
if (rr>mid)
maxn=max(maxn,solve4(now<<1|1,mid+1,r,ll,rr,x));
return maxn;
}
int solve5(int now,int l,int r,int ll,int rr,int x)
{
if (l>=ll&&r<=rr)
{
return find_next_max(now,x);
}
int mid=(l+r)/2;
int minn=1000000000;
if (ll<=mid)
minn=min(minn,solve5(now<<1,l,mid,ll,rr,x));
if (rr>mid)
minn=min(minn,solve5(now<<1|1,mid+1,r,ll,rr,x));
return minn;
}
int main()
{
freopen("input.txt","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d",&a[i]),maxn=max(maxn,a[i]);
for (int i=1;i<=n;i++)
pointchange(1,1,n,i,a[i]);
for (int i=1;i<=m;i++)
{
int op,x,y,k; scanf("%d%d%d",&op,&x,&y);
if (op!=3) scanf("%d",&k);
switch(op)
{
case 1: printf("%d\n",solve1(1,1,n,x,y,k)+1); break; //注意rank是比他小的个数+1
case 2: solve2(x,y,k); break;
case 3: solve3(1,1,n,x,y); maxn=max(maxn,y); a[x]=y; break;
case 4: printf("%d\n",solve4(1,1,n,x,y,k)); break;
case 5: printf("%d\n",solve5(1,1,n,x,y,k)); break;
}
}
} </span>