T1 魔卡少女
给出N个数,M个操作。操作有修改和询问两种,每次修改将一个数改成另一个数,每次询问一个区间的所有连续子区间的异或和。
线段树,对于二进制的每一位开一颗线段树。对于每一个区间,维护其从左开始但不到右结束,不从左开始但到右结束,既不从左开始也不到右结束的异或值为0,1的区间个数,还有总共的异或值。转移很显然。设由区间y,z合并到x,则
x.tot=y.tot^z.tot;
x.sum[0]=(y.sum[0]+z.sum[0]+
y.r[0]*z.l[0]+y.r[1]*z.l[1]+y.r[0]+z.l[0])%mo;
x.sum[1]=(y.sum[1]+z.sum[1]+
y.r[0]*z.l[1]+y.r[1]*z.l[0]+y.r[1]+z.l[1])%mo;
x.l[0]=(y.l[0]+1-y.tot+z.l[y.tot])%mo;
x.l[1]=(y.l[1]+y.tot+z.l[y.tot^1])%mo;
x.r[0]=(z.r[0]+1-z.tot+y.r[z.tot])%mo;
x.r[1]=(z.r[1]+z.tot+y.r[z.tot^1])%mo;
这里的^表示异或。
然后发现这样打的常数巨大,容易被卡。于是发现sum0根本就没有用,扔了。l0,r0虽然要用来转移,但是可以算出来,因为l0+l1=r0+r1=区间长度-1,于是也可以扔掉,然后合并就变成了
x.tot=y.tot^z.tot;
int l0=lenl-y.r,r0=lenr-z.l;
x.sum=(y.sum+z.sum+l0*z.l+y.r*r0+y.r+z.l)%mo;
x.l=y.l+y.tot;
if (y.tot) x.l+=r0;
else x.l+=z.l;
x.r=z.r+z.tot;
if (z.tot) x.r+=l0;
else x.r+=y.r;
其中lenl,lenr分别表示左右区间的长度。
这样就可以过了。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define N 100005
#define ll long long
#define mo 100000007
using namespace std;
struct note{
ll sum;
int l,r,tot;
}t[10][N*3],ans;
char ch[1];
int n,m,x,y,a[N],ansl;
void merge(note &x,note y,note z,int lenl,int lenr){
x.tot=y.tot^z.tot;
int l0=lenl-y.r,r0=lenr-z.l;
x.sum=(y.sum+z.sum+l0*z.l+y.r*r0+y.r+z.l)%mo;
x.l=y.l+y.tot;if (y.tot) x.l+=r0;else x.l+=z.l;
x.r=z.r+z.tot;if (z.tot) x.r+=l0;else x.r+=y.r;
}
void build(int v,int l,int r,int x) {
if (l==r) {
t[x][v].tot=((a[l]&(1<<x))>0);return;
}
int m=(l+r)/2;
build(v*2,l,m,x);build(v*2+1,m+1,r,x);
merge(t[x][v],t[x][v*2],t[x][v*2+1],m-l,r-m-1);
}
void change(int v,int l,int r,int x,int y,int z) {
if (l==r) {
t[z][v].tot=y;return;
}
int m=(l+r)/2;
if (x<=m) change(v*2,l,m,x,y,z);
else change(v*2+1,m+1,r,x,y,z);
merge(t[z][v],t[z][v*2],t[z][v*2+1],m-l,r-m-1);
}
void find(int v,int l,int r,int x,int y,int z) {
if (l==x&&r==y) {
if (ans.tot<0) ans=t[z][v];
else merge(ans,ans,t[z][v],l-ansl-1,r-l);
return;
}
int m=(l+r)/2;
if (y<=m) find(v*2,l,m,x,y,z);
else if (x>m) find(v*2+1,m+1,r,x,y,z);
else {
find(v*2,l,m,x,m,z);find(v*2+1,m+1,r,m+1,y,z);
}
}
int main() {
scanf("%d",&n);
fo(i,1,n) scanf("%d",&a[i]);
fo(i,0,9) build(1,1,n,i);
for(scanf("%d",&m);m;m--) {
scanf("%s%d%d",ch,&x,&y);
if (ch[0]=='M') fo(i,0,9) change(1,1,n,x,((y&(1<<i))>0),i);
else {
ll sum=0;
fo(i,0,9) {
ans.tot=-1;ansl=x;
find(1,1,n,x,y,i);
sum=(sum+(ans.sum+
ans.l+ans.r+ans.tot)
*(1<<i)%mo)%mo;
}
printf("%lld\n",sum);
}
}
}