说在前面
感觉…好名字都被库文件取完了
一开始变量名是exp,本地编译和math库冲突了,换成index本地编译过了
然后交了一发CE了,看编译信息发现又string库冲突了…
= =…简直可恶
题目
题目大意
给定一个字符串(长度不超过100000),请求出符合条件的子序列数
条件:
1. 该子序列是一个回文序列
2. 该序列的位置不可以连续(比如aabaa中,aba不合法,因为它是位置连续的)
输入输出格式
输入格式:
输入一行,包含一个字符串
输出格式:
输出合法子序列数对1e+7取模的结果
解法
先考虑如何计算所有的回文序列,算出来之后减去不合法的就是答案。
对于每个位置,都统计一遍以该位置为中心的对称位置数量(设为k),那么以该位置为中心的回文序列有
2k−1
2
k
−
1
个。
不合法的回文序列是连续的,也就是回文子串,可以用manacher在
Θ(n)
Θ
(
n
)
内搞定。
现在难点就是如何计算每个位置的k值。暴力算法很简单,对于每个位置都 Θ(n) Θ ( n ) 的对比左右字符是否相同即可,相同就k++。把这个操作写成式子的形式,就是这样: ∑i=0[a(pos−i)==a(pos+i)] ∑ i = 0 [ a ( p o s − i ) == a ( p o s + i ) ] ,这是一个卷积的形式。也就是说,如果把字符串中的a写成1,b写成0,会构成一个字符串长度项的多项式,这个多项式自乘得到的多项式中,第i项的系数也就是第i个位置的 ka k a 值。同理把b写成1,a写成0,可以得到每个位置的 kb k b 值,于是k值就知道了。
上面的多项式乘法显然可以用FFT优化到 Nlog2N N l o g 2 N 级别,即可通过此题
下面是自带大常数的代码
/**************************************************************
Problem: 3160
User: Izumihanako
Language: C++
Result: Accepted
Time:2252 ms
Memory:11892 kb
****************************************************************/
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
const double PI = 3.1415926535897932384626 ;
const int mmod = 1000000007 ;
char ss[100005] , Mss[200005] ;
int lens , Mlens , N , expos[262145] , Index[200005] , pal[200005] ;
struct Complex_{
double x , y ;
} a[262145] , b[262145] ;
typedef Complex_ cpx ;
cpx operator + ( const cpx &A , const cpx &B ){
return ( cpx ){ A.x + B.x , A.y + B.y } ;
}
cpx operator - ( const cpx &A , const cpx &B ){
return ( cpx ){ A.x - B.x , A.y - B.y } ;
}
cpx operator * ( const cpx &A , const cpx &B ){
return ( cpx ){ A.x*B.x - A.y*B.y , A.x*B.y + A.y*B.x } ;
}
cpx operator / ( const cpx &A , const double &p ){
return ( cpx ){ A.x / p , A.y / p } ;
}
void FFT( cpx *a , int fix ){
for( int i = 1 ; i < N ; i ++ )
if( expos[i] > i ) swap( a[i] , a[ expos[i] ] ) ;
for( int i = 2 ; i <= N ; i <<= 1 ){
cpx id = ( cpx ){ cos( 2 * PI * fix / i ) , sin( 2 * PI * fix / i ) } ;
for( int j = 0 ; j < N ; j += i ){
cpx W = ( cpx ){ 1 , 0 } ;
for( int k = j ; k < j + i/2 ; k ++ ){
cpx u = a[k] , v = a[k+i/2] * W ;
a[k] = u + v ;
a[k+i/2] = u - v ;
W = W * id ;
}
}
}
if( fix == -1 )
for( int i = 0 ; i < N ; i ++ )
a[i] = a[i] / N ;
}
void manacher(){
for( int i = 1 , id = 0 , mx = 0 ; i < Mlens ; i ++ ){
if( mx > i ) pal[i] = min( mx - i , pal[ 2*id-i ] ) ;
else pal[i] = 1 ;
while( Mss[ i+pal[i] ] == Mss[ i-pal[i] ] ) pal[i] ++ ;
if( i + pal[i] > mx ){
mx = i + pal[i] ;
id = i ;
}
}
}
long long s_pow( long long x , int b ){
long long rt = 1 ;
while( b ){
if( b&1 ) rt = rt * x %mmod ;
x = x * x %mmod ; b >>= 1 ;
//printf( "%d\n" , b ) ;
} return rt ;
}
void init(){
Mss[0] = '+' , Mss[1] = '#' ; Mlens = 2 ;
for( int i = 0 ; i < lens ; i ++ )
Mss[Mlens++] = ss[i] , Mss[Mlens++] = '#' ;
for( N = 1 ; N < 2 * lens - 2 ; N <<= 1 ) ;
for( int i = 1 , x = 0 ; i < N ; i ++ ){
int tmp = N >> 1 ;
while( tmp&x ) x ^= tmp , tmp >>= 1 ;
x ^= tmp ; expos[i] = x ;
}
}
void solve(){
//puts( "Entre solve" ) ;
for( int i = 0 ; i < lens ; i ++ ){
a[i] = ( cpx ){ ss[i] == 'a' ? 1.0 : 0 , 0 } ;
b[i] = ( cpx ){ ss[i] == 'b' ? 1.0 : 0 , 0 } ;
}
FFT( a , 1 ) ; FFT( b , 1 ) ;
//puts( "DFT end" ) ;
for( int i = 0 ; i < N ; i ++ )
a[i] = a[i] * a[i] + b[i] * b[i] ;
FFT( a , -1 ) ;
//puts( "IDFT end" ) ;
for( int i = 2 * lens - 2 ; i >= 0 ; i -- )
Index[i] = ( ( a[i].x + 0.5 ) + 1 ) / 2 /*, printf( "Index[%d] = %d\n" , i , Index[i] )*/ ;
manacher() ;
//puts( "Manacher end" ) ;
long long ans = 0 ;
for( int i = 2 ; i < Mlens ; i ++ )
ans = ( ans + ( s_pow( 2 , Index[i-2] ) - 1 ) - pal[i]/2 )%mmod ;
printf( "%lld" , ans ) ;
}
int main(){
scanf( "%s" , ss ) ; lens = strlen( ss ) ;
init() ; solve() ;
}