题意:给你一个序列,有翻转区间和查询操作,每次查询一段区间里面的所有字符能组成的不同回文串的个数。 字符集大小 = 10 , n , q ≤ 1 0 5 n,q\leq 10^5 n,q≤105 。
solution:
splay 好题。对于每个节点,用桶记录每种字符出现的次数,区间翻转就把 l 的前驱和 r 的后驱分别提到根节点和根的右子树。问题就转化为求 [l,r] 中各个字符出现的次数。最后用组合数算一下不同回文串个数即可。注意 rotate 要赋 fa 指针, build 函数要传实参和赋 fa 指针。同时 rotate 后不要忘记更新根节点。
时间复杂度 O ( S n l o g n ) O(Snlogn) O(Snlogn) 。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int mx=1e5+5;
const int mod=1e9+7;
struct node{
node *ch[2];
node *fa;
int cnt[10];
int siz;
int key;
int lazy;
}tree[mx];
node *NIL,*Root,*ncnt;
int n,m;
ll fac[mx],inv[mx];
char s[mx];
ll fpow(ll x,ll y) {
ll mul(1);
for(;y;y>>=1) {
if(y&1) mul=mul*x%mod;
x=x*x%mod;
}
return mul;
}
ll C(ll x,ll y) {
return fac[x]*inv[y]%mod*inv[x-y]%mod;
}
void Init() {
NIL=&tree[0];
ncnt=&tree[1];
Root=NIL->ch[0]=NIL->ch[1]=NIL->fa=NIL;
}
void PushUp(node *rt) {
rt->siz=rt->ch[0]->siz+rt->ch[1]->siz+1;
for(int i=0;i<=9;i++) {
rt->cnt[i]=rt->ch[0]->cnt[i]+rt->ch[1]->cnt[i]+('a'+i==s[rt->key]);
}
}
void PushDown(node *rt) {
if(!rt->lazy) return;
swap(rt->ch[0],rt->ch[1]);
rt->ch[0]->lazy^=1; rt->ch[1]->lazy^=1;
rt->lazy=0;
}
node *Newnode(int val) {
node *p=ncnt++;
p->ch[0]=p->ch[1]=p->fa=NIL;
p->siz=1;
if(val>=1&&val<=n) p->cnt[s[val]-'a']++,p->key=val;
return p;
}
void Rotate(node *x) {
node *y=x->fa;
int d=(x==y->ch[0]);
x->fa=y->fa;
if(y->fa!=NIL) y->fa->ch[y->fa->ch[1]==y]=x;
y->ch[!d]=x->ch[d];
if(x->ch[d]!=NIL) x->ch[d]->fa=y;
x->ch[d]=y;
y->fa=x;
if(y==Root) Root=x;
PushUp(y);
PushUp(x);
}
void Splay(node *x,node *rt) {
node *y,*z;
while(x->fa!=rt) {
y=x->fa;
z=y->fa;
if(z==rt) {
Rotate(x);
}
else {
if((x==y->ch[0])^(y==z->ch[0])) {
Rotate(x);
}
else {
Rotate(y);
}
Rotate(x);
}
}
}
void Build(node *&rt,node *fa,int l,int r) {
if(l>r) return;
int mid=l+r>>1;
rt=Newnode(mid); rt->fa=fa;
Build(rt->ch[0],rt,l,mid-1);
Build(rt->ch[1],rt,mid+1,r);
PushUp(rt);
}
node *Select(node *rt,int k) {
if(rt==NIL) return NIL;
PushDown(rt);
if(k<=rt->ch[0]->siz) return Select(rt->ch[0],k);
else if(k<=rt->ch[0]->siz+1) return rt;
else return Select(rt->ch[1],k-rt->ch[0]->siz-1);
}
ll Query(int l,int r) {
node *x=Select(Root,l),*y=Select(Root,r+2);
Splay(x,NIL),Splay(y,Root);
node *z=Root->ch[1]->ch[0];
ll mul(1); int cnt(0),cnt2(0);
for(int i=0;i<=9;i++) {
if(z->cnt[i]%2==1) cnt++;
cnt2+=z->cnt[i]/2;
}
if(cnt>1) return 0;
for(int i=0;i<=9;i++) {
if(z->cnt[i]==0) continue;
mul=mul*C(cnt2,z->cnt[i]/2)%mod;
cnt2-=z->cnt[i]/2;
}
return mul;
}
void Reverse(int l,int r) {
node *x=Select(Root,l),*y=Select(Root,r+2);
Splay(x,NIL),Splay(y,Root);
Root->ch[1]->ch[0]->lazy^=1;
}
int main() {
scanf("%d%d%s",&n,&m,s+1);
Init();
Build(Root,NIL,0,n+1);
fac[0]=1; for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
inv[n]=fpow(fac[n],mod-2);
for(int i=n;i>=1;i--) {
inv[i-1]=inv[i]*i%mod;
}
for(int i=1;i<=m;i++) {
int op,l,r;
scanf("%d%d%d",&op,&l,&r);
if(op==2) {
printf("%lld\n",Query(l,r));
}
else {
Reverse(l,r);
}
}
}