【题目链接】 http://acm.hdu.edu.cn/showproblem.php?pid=5737
【题目大意】
给出两个序列a和b,要求实现两个操作:
1. 将a序列的一个区间中的所有数改成同一个数
2. 查询一个区间内a数组中大于相同下标b数组中的数的数。
【题解】
考虑到b数组是不变的,可以在归并树上预处理出b数组中每个元素在树的左右儿子中的排名,在归并树建立时,可以求出每个区间ai>bi的个数,在发生区间修改的时候,在根节点的有序的b数组中二分查找修改值的排名,将信息下传就可以统计每个区间ai>bi的数目,由于操作具有区间性质,因此可以打延迟标记。查询时下传标记统计答案即可。
【代码】
#include <cstdio>
#include <algorithm>
using namespace std;
const int N=100010,M=300000,U=2000000,mod=1e9+7;
int T,n,m,ans,a[N],b[N],A,B,C=~(1<<31),L,R,x,Ans;
int st[M],en[M],v[M],tag[M],l[U],r[U],seq[U],cnt;
void addtag(int x,int p){v[x]=p?p-st[x]+1:0;tag[x]=p;}
void pb(int x){
if(tag[x]<0)return;
addtag(x<<1,l[tag[x]]);addtag(x<<1|1,r[tag[x]]);
tag[x]=-1;
}
void build(int x,int a,int b){
tag[x]=-1;
if(a==b){st[x]=++cnt;seq[cnt]=::b[a];en[x]=cnt;v[x]=::a[a]>=::b[a];return;}
int mid=(a+b)>>1;
build(x<<1,a,mid),build(x<<1|1,mid+1,b);
v[x]=v[x<<1]+v[x<<1|1];
int al=st[x<<1],ar=en[x<<1],bl=st[x<<1|1],br=en[x<<1|1];
st[x]=cnt+1;
while(al<=ar&&bl<=br)seq[++cnt]=seq[al]<seq[bl]?seq[al++]:seq[bl++];
while(al<=ar)seq[++cnt]=seq[al++];
while(bl<=br)seq[++cnt]=seq[bl++];
en[x]=cnt;
al=st[x<<1],bl=st[x<<1|1];
for(int i=st[x];i<=cnt;i++){
while(al<=ar&&seq[al]<=seq[i])al++;
while(bl<=br&&seq[bl]<=seq[i])bl++;
l[i]=al-1;r[i]=bl-1;
if(l[i]<st[x<<1])l[i]=0;
if(r[i]<st[x<<1|1])r[i]=0;
}
}
void change(int x,int a,int b,int p){
if(L<=a&&b<=R){addtag(x,p);return;}pb(x);
int mid=(a+b)>>1;
if(L<=mid)change(x<<1,a,mid,l[p]);
if(R>mid)change(x<<1|1,mid+1,b,r[p]);
v[x]=v[x<<1]+v[x<<1|1];
}
void ask(int x,int a,int b){
if(L<=a&&b<=R){ans+=v[x];return;}pb(x);
int mid=(a+b)>>1;
if(L<=mid)ask(x<<1,a,mid);
if(R>mid)ask(x<<1|1,mid+1,b);
v[x]=v[x<<1]+v[x<<1|1];
}
int lower(int x){
int l=st[1],r=en[1],pos=0;
while(l<=r){
int mid=(l+r)>>1;
if(seq[mid]<=x)pos=mid,l=mid+1;
else r=mid-1;
}return pos;
}
int rnd(){
A=(36969+(ans>>3))*(A&65535)+(A>>16);
B=(18000+(ans>>3))*(B&65535)+(B>>16);
return(C&((A<<16)+B))%1000000000;
}
int main(){
scanf("%d",&T);
while(T--){
scanf("%d%d%d%d",&n,&m,&A,&B);
for(int i=1;i<=n;i++)scanf("%d",a+i);
for(int i=1;i<=n;i++)scanf("%d",b+i);
cnt=Ans=ans=0; build(1,1,n);
for(int i=1;i<=m;i++){
L=rnd()%n+1,R=rnd()%n+1,x=rnd()+1;
if(L>R)swap(L,R);
if((L+R+x)&1)change(1,1,n,lower(x));
else{
ans=0; ask(1,1,n);
Ans=(1LL*ans*i+Ans)%mod;
}
}printf("%d\n",Ans);
}return 0;
}