题目描述
给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。
输入
第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
输出
对于每组询问输出第K小的数。
样例输入
2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3
2 1
3 4
1 2 1 2 1
1 1 2 2 3
样例输出
1
3
3
提示
矩阵中数字是109以内的非负整数;
20%的数据:N<=100,Q<=1000;
40%的数据:N<=300,Q<=10000;
60%的数据:N<=400,Q<=30000;
100%的数据:N<=500,Q<=60000。
考虑暴力,对于每次询问二分答案,求权值$\le mid$的点的个数是否有$k$个,显然时间复杂度爆炸。
我们将所有询问一起二分,也就是整体二分。
先将每个点的权值排序,每次将权值$\le mid$的点加入到矩阵中,然后对当前要处理的所有询问进行查询。
如果查询矩形内点个数大于等于$k$那么说明这个询问的答案要在$[l,mid]$中,就将这个询问归为左区间,否则将询问的$k$减掉这次查询的结果,归为右区间。这样递归下去即可得到每个询问的答案。
至于查询矩形内点数用二维树状数组差分一下即可。
每层处理完不要忘记把树状数组清空。
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<bitset>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int n,m;
int v[510][510];
struct lty
{
int x,y,val;
}a[250010];
int cnt;
int num;
struct miku
{
int a,b,c,d,k;
}q[60010];
int ans[60010];
int s[60010];
int ql[60010];
int qr[60010];
bool cmp(lty a,lty b)
{
return a.val<b.val;
}
void add(int x,int y,int val)
{
for(int i=x;i<=n;i+=i&-i)
{
for(int j=y;j<=n;j+=j&-j)
{
v[i][j]+=val;
}
}
}
int ask(int x,int y)
{
int res=0;
for(int i=x;i;i-=i&-i)
{
for(int j=y;j;j-=j&-j)
{
res+=v[i][j];
}
}
return res;
}
int query(int a,int b,int c,int d)
{
return ask(c,d)-ask(c,b-1)-ask(a-1,d)+ask(a-1,b-1);
}
void solve(int l,int r,int L,int R)
{
if(L>R)
{
return ;
}
if(l==r)
{
for(int i=L;i<=R;i++)
{
ans[s[i]]=a[l].val;
}
return ;
}
int mid=(l+r)>>1;
for(int i=l;i<=mid;i++)
{
add(a[i].x,a[i].y,1);
}
int sl=L,sr=R;
for(int i=L;i<=R;i++)
{
int res=query(q[s[i]].a,q[s[i]].b,q[s[i]].c,q[s[i]].d);
if(res>=q[s[i]].k)
{
ql[sl++]=s[i];
}
else
{
qr[sr--]=s[i];
q[s[i]].k-=res;
}
}
for(int i=L;i<sl;i++)
{
s[i]=ql[i];
}
for(int i=sr+1;i<=R;i++)
{
s[i]=qr[i];
}
for(int i=l;i<=mid;i++)
{
add(a[i].x,a[i].y,-1);
}
solve(l,mid,L,sl-1),solve(mid+1,r,sr+1,R);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
num++;
scanf("%d",&a[num].val);
a[num].x=i,a[num].y=j;
}
}
sort(a+1,a+1+num,cmp);
for(int i=1;i<=m;i++)
{
cnt++;
scanf("%d%d%d%d%d",&q[cnt].a,&q[cnt].b,&q[cnt].c,&q[cnt].d,&q[cnt].k);
s[i]=i;
}
solve(1,num,1,cnt);
for(int i=1;i<=m;i++)
{
printf("%d\n",ans[i]);
}
}