题意:给定n个数,m次查询(a,b,k),表示查询第a个数到第b个数(闭区间[a,b])中的第k小数。其中这m个区间保证没有一个完全覆盖另外一个。
思路:首先考虑没有区间覆盖的意思:对于最小的区间终点,它对应的起点一定是最小的区间起点。首先离散化。然后使用线段树,每个节点(如表示区间为[a,b])维护一个值small表示位于其左儿子的个数。然后对于每个区间(对区间先排序)更新查询即可。时间复杂度为O(nlogn+n+mlogn)。
其中离散化的时候使用二分可以AC,使用map(poj平台不支持unordered_map)会超时。
另外此题还可以使用树状数组、treap、划分树等数据结构。
#include <cstdio>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <queue>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <iostream>
#define N 100005
#define M 50005
#define INF 0x3fffffff
using namespace std;
struct lt{
int left,right;
int small;
}tree[N<<2];
struct node{
int a,b,k,id;
}p[M];
int res[M];
int n,m;
int s[N],t[N];
bool cmp(node x,node y){
return x.a < y.a;
}
void build(int r,int a,int b){
tree[r].left = a;
tree[r].right = b;
tree[r].small = 0;
if(a == b)
return;
build(r<<1, a, (a+b)/2);
build((r<<1)+1, (a+b)/2+1, b);
}
void update(int r,int x,int flag){
if(tree[r].left == tree[r].right)
return;
int mid = (tree[r].left+tree[r].right)/2;
if(x <= mid){
tree[r].small += flag;
update(r<<1, x, flag);
}else
update(1+(r<<1), x, flag);
}
int solve(int r,int k){
if(tree[r].left == tree[r].right)
return tree[r].left;
int mid = (tree[r].left+tree[r].right)>>1;
if(k <= tree[r].small)
return solve(r<<1, k);
return solve(1+(r<<1), k-tree[r].small);
}
int find(int begin, int end, int x){
int mid;
while(begin <= end){
mid = (begin+end)>>1;
if(t[mid] == x)
return mid;
if(t[mid] > x)
end = mid-1;
else
begin = mid+1;
}
return -1;
}
int main(){
scanf("%d %d",&n,&m);
for(int i = 1;i<=n;i++){
scanf("%d",&s[i]);
t[i] = s[i];
}
sort(t+1, t+n+1);
int num = unique(t+1, t+n+1) - t-1;
for(int i = 1;i<=n;i++)
s[i] = find(1, num, s[i]);
build(1,1,num);
for(int i = 0;i<m;i++){
scanf("%d %d %d",&p[i].a,&p[i].b,&p[i].k);
p[i].id = i;
}
sort(p, p+m, cmp);
int beg = p[0].a, end = beg;
for(int i = 0;i<m;i++){
while(beg<end && beg<p[i].a)
update(1,s[beg++],-1);
if(end < p[i].a)
end = p[i].a;
while(end <= p[i].b)
update(1,s[end++],1);
res[p[i].id] = t[solve(1,p[i].k)];
}
for(int i = 0;i<m;i++)
printf("%d\n",res[i]);
return 0;
}