题目如下:
Description
对于序列A,它的逆序对数定义为满足i<j,且Ai>Aj的数对(i,j)的个数。给1到n的一个排列,按照某种顺序依次删除m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
Input
输入第一行包含两个整数n和m,即初始元素的个数和删除的元素个数。以下n行每行包含一个1到n之间的正整数,即初始排列。以下m行每行一个正整数,依次为每次删除的元素。
Output
输出包含m行,依次为删除每个元素之前,逆序对的个数。
HINT
N<=100000 M<=50000
为了方便,将删除操作倒着执行,变成插入操作。
这样就变成每次插入一个元素,求有多少x比它小,值比它大的数字,和有多少x比它大,值比它小的数字。
有x,y,t三个维度。y表示x位置的数字,t表示这个数字是何时被插入进来的。
这是一个三维偏序问题,可以用树套树或者CDQ分治来做。这里先尝试CDQ分治。
我们按x维度进行排序(这个读入过程中就直接排好了),对t维度进行分治。
分治过程中,我们要处理所有t<=mid的数对t>=mid+1的答案的影响。
为了降低复杂度,首先把所有t<=mid的数都放在数组的前半部分,t>=mid+1的都放在后半部分,两半部分以内的x依然是有序的。
对于每个t>=mid+1,找到x比它更小,y比它更大的数,在树状数组上更新,然后查询,找x比它更大,y比它更小的数,同样这么做。
然后递归处理(l,mid),(mid+1,r)即可。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<string>
#define ll long long
using namespace std;
int n,m;
ll tree[100111];
int cnt=0;
ll ans[100010],ansl[100010],ansr[100010];
struct node
{
ll x,y,t;
node(){t=0;}
}a[100111],temp[100111];
int antiid[100111];
void add(int x,int k)
{
for(x;x<=n;x+=x&-x)
tree[x]+=k;
}
int query(int x)
{
ll res=0;
for(x;x>0;x-=x&-x)
res+=tree[x];
return res;
}
void cdq(int l,int r)
{
//cout<<l<<" "<<r<<endl;
if(l==r)return;
int mid=l+r>>1,lp=l,rp=mid+1;
for(int i=l;i<=r;++i)
if(a[i].t<=mid)temp[lp++]=a[i];
else temp[rp++]=a[i];
for(int i=l;i<=r;++i)
a[i]=temp[i];
//对于(维度t)右边的每一个点,作为询问来处理
//找到所有x更小,y更大的插入操作
int j=l;
for(int i=mid+1;i<=r;++i)
{
for(;j<=mid&&(a[j].x<a[i].x);++j)add(a[j].y,1);
//j-l表示所有x更小的,query表示有几个x更小,且y更小的
ansl[a[i].t]+=(j-l)-query(a[i].y);
}
for(int i=l;i<j;++i)add(a[i].y,-1);
//找到所有x更大,y更小的插入操作
j=mid;
for(int i=r;i>=mid+1;--i)
{
for(;j>=l&&(a[j].x>a[i].x);--j)add(a[j].y,1);
//query表示有几个x更大,且y更小的
ansr[a[i].t]+=query(a[i].y-1);
}
for(int i=j+1;i<=mid;++i)add(a[i].y,-1);
cdq(l,mid);cdq(mid+1,r);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i)
{
scanf("%d",&a[i].y);
antiid[a[i].y]=a[i].x=i;
}
int tim=n;
for(int i=1;i<=m;++i)
{
int k;
scanf("%d",&k);
a[antiid[k]].t=tim--;
}
for(int i=1;i<=n;++i)
if(!a[i].t)a[i].t=tim--;
cdq(1,n);
for(int i=1;i<=n;++i)
ans[i]=ans[i-1]+ansl[i]+ansr[i];
for(int i=n;i>=n-m+1;--i)
printf("%lld\n",ans[i]);
return 0;
}