如赛后官方题解所说,本题应考虑相反情况,即不与目标矩形相交的所有矩形数。显然矩形总数为n(n-1)/2。易知目标矩形上方的点数为n-u,右方的点数为n-r,左方的点数为l-1,下方点数为d-1. 不与目标矩形相交的矩形就是分别由这些区域的点构成。(点即原题所述的单位正方形,即一个格子)。
依次扣除这些点构成的矩形后,显然左上方,左下方,右上方,右下方的矩形被重复计算了一边。所以这题的关键就在于求出这些区域的点数。具体来说,就是对于查询l, d, r, u, 分别求出满足
(x∈[1,l−1],y∈[1,d−1])
(左下方),
(x∈[r+1,n],y∈[1,d−1])
(右下方),
(x∈[1,l−1],y∈[u+1,n])
(左上方),
(x∈[r+1,n],y∈[u+1,n])
(右上方),
的点的总数。也就是实现一个二维区间的离线查询。
要想不TLE,就要用到一种神奇的数据结构:主席树。主席树是一种特殊的线段树。在这道题里,相当于建立了n棵线段树,第i棵线段树记录的是原数组前i项构成的子数组的信息(即a[1]到a[i]),线段树的每一个结点三个值,l,r,dat,l和r相当熟悉,就是该结点的左右儿子的位置(在树的数组T中)。而就如同普通的线段树一样,每一个结点都会对应到一个区间[L,R](为了避免跟前面的l,r混淆,采用大写)。那么结点的记录值dat就是表示原来的子数组(1..i)中有多少个数出现在区间[l,r], 即
这样的结构就能实现二维区间查询。这里我们也用到了前缀和的思想。如果我们想统计出 (x∈[xl,xr],y∈[yl,yr]) 的点的个数,那么我们可以分别对tree[xl - 1], tree[xr]两棵树查询 y∈[yl,yr] 的点的个数,再对两树相减就是结果。
但等等,建n棵线段树,不会TLE + MLE吗?而主席树的神奇之处就在这个地方。我们考虑相邻的两棵线段树,tree[i-1] 和 tree[i]。比较一下,我们发现,tree[i] 仅仅是在 tree[i-1] 这棵树的基础上更新了新的一个值a[i]。所以,tree[i]较于tree[i-1]不同的地方只有a[i]所在的链的结点上,别的结点完全一致。
举个例子。
考虑这一棵线段树,假如我们要新加一个值(a[i]=4), 那么线段树要改的结点应该是4, [4,5], [4,7], [0,7]。这几个结点的dat值全部加一。其余结点不发生改变。
所以,我们在依次建n棵树的时候,只需要将发生改变的结点新建一份加进树的数组中,其余不发生改变的结点两棵树共用。
过程如下:
1. 新建一棵表示区间范围是[1,n]的空树,数组T依次存放树的结点,sz(size of tree)表示当前已经用过的树的结点数。
2.依次用原数组中的值a[i]更新树。注意到根结点(表示区间[1,n])是必然更新的,我们用数组rt存在根结点的位置,rt[i]表示第i棵树的根结点位置。在更新的时候,变量last记录当前要更新结点对应于上一棵树的结点的位置。区间从顶往下更新,也就是首先更新根结点rti, 那么上一棵树的对应结点就是rt[i-1]. 对区间二分,看a[i](要更新的值)在哪个半区间,就往哪个子结点更新,而另一个子结点不需要更新,直接连上当前更新的结点即可。
见代码:
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 200019;
typedef long long ll;
int rt[N],a[N];
int n,q;
struct tree
{
int l, r, dat;
tree(int l = 0, int r = 0, int dat = 0):l(l),r(r),dat(dat){}
}T[N * 20];
int sz; // tree size
void update(int &o, int l, int r, int last, int ele)
//update inteval[l, r] (add ele) the index of new node will be stored in o
{
o = ++sz;
T[o].l = T[last].l;
T[o].r = T[last].r;
T[o].dat = T[last].dat + 1;
if(l == r) return;
int mid = (l + r) >> 1;
//printf("T[%d].l = %d, .r = %d, .dat = %d, last = %d, ele = %d, mid = %d\n",o, T[o].l, T[o].r, T[o].dat, last, ele, mid);
if (ele <= mid) update(T[o].l, l, mid, T[last].l, ele);
else update(T[o].r, mid + 1, r, T[last].r, ele);
}
int query(int t1, int t2, int l, int r, int ql, int qr)//query inteval[ql, qr] search interval [l, r] tree node t1 t2
{
if (r < ql || l > qr) return 0;
if (ql <= l && r <= qr) return T[t2].dat - T[t1].dat;
int mid = (l + r) >> 1;
return query(T[t1].l, T[t2].l, l, mid, ql, qr) + query(T[t1].r, T[t2].r, mid + 1, r, ql, qr);
}
int query_client(int xl, int xr, int yl, int yr)
{
if (xl > xr || yl > yr) return 0;
return query(rt[xl - 1], rt[xr], 1, n, yl, yr);
}
int main()
{
scanf("%d%d",&n,&q);
sz = 0;
for (int i = 1; i <= n; i++)
{
scanf("%d", a + i);
update(rt[i], 1, n, rt[i - 1], a[i]);
}
for (int i = 0; i < q; i++)
{
int l,d,r,u;
scanf("%d%d%d%d",&l,&d,&r,&u);
ll ans = 1LL * n * (n - 1) / 2;
ll x = l - 1;
ans -= x * (x - 1) / 2;
x = d - 1;
ans -= x * (x - 1) / 2;
x = n - r;
ans -= x * (x - 1) / 2;
x = n - u;
ans -= x * (x - 1) / 2;
x = query_client(1, l - 1, 1, d - 1);
ans += x * (x - 1) / 2;
x = query_client(r + 1, n, 1, d - 1);
ans += x * (x - 1) / 2;
x = query_client(1, l - 1, u + 1, n);
ans += x * (x - 1) / 2;
x = query_client(r + 1, n, u + 1, n);
ans += x * (x - 1) / 2;
printf("%lld\n",ans);
}
}