[BZOJ3160]-万径人踪灭-manacher+FFT

本文介绍了解决BZOJ3160题目的方法,该题要求计算给定字符串中所有符合条件的回文子序列的数量。文章详细阐述了利用FFT算法优化多项式乘法的过程,并给出了具体的实现代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

说在前面

感觉…好名字都被库文件取完了
一开始变量名是exp,本地编译和math库冲突了,换成index本地编译过了
然后交了一发CE了,看编译信息发现又string库冲突了…

= =…简直可恶


题目

BZOJ3160传送门

题目大意

给定一个字符串(长度不超过100000),请求出符合条件的子序列数
条件:
1. 该子序列是一个回文序列
2. 该序列的位置不可以连续(比如aabaa中,aba不合法,因为它是位置连续的)

输入输出格式

输入格式:
输入一行,包含一个字符串

输出格式:
输出合法子序列数对1e+7取模的结果


解法

先考虑如何计算所有的回文序列,算出来之后减去不合法的就是答案。
对于每个位置,都统计一遍以该位置为中心的对称位置数量(设为k),那么以该位置为中心的回文序列有 2k1 2 k − 1 个。
不合法的回文序列是连续的,也就是回文子串,可以用manacher在 Θ(n) Θ ( n ) 内搞定。

现在难点就是如何计算每个位置的k值。暴力算法很简单,对于每个位置都 Θ(n) Θ ( n ) 的对比左右字符是否相同即可,相同就k++。把这个操作写成式子的形式,就是这样: i=0[a(posi)==a(pos+i)] ∑ i = 0 [ a ( p o s − i ) == a ( p o s + i ) ] ,这是一个卷积的形式。也就是说,如果把字符串中的a写成1,b写成0,会构成一个字符串长度项的多项式,这个多项式自乘得到的多项式中,第i项的系数也就是第i个位置的 ka k a 值。同理把b写成1,a写成0,可以得到每个位置的 kb k b 值,于是k值就知道了。

上面的多项式乘法显然可以用FFT优化到 Nlog2N N l o g 2 N 级别,即可通过此题


下面是自带大常数的代码

/**************************************************************
    Problem: 3160
    User: Izumihanako
    Language: C++
    Result: Accepted
    Time:2252 ms
    Memory:11892 kb
****************************************************************/
 
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
 
const double PI = 3.1415926535897932384626 ;
const int mmod = 1000000007 ;
char ss[100005] , Mss[200005] ;
int lens , Mlens , N , expos[262145] , Index[200005] , pal[200005] ;
struct Complex_{
    double x , y ;
} a[262145] , b[262145] ;
 
typedef Complex_ cpx ;
cpx operator + ( const cpx &A , const cpx &B ){
    return ( cpx ){ A.x + B.x , A.y + B.y } ;
}
cpx operator - ( const cpx &A , const cpx &B ){
    return ( cpx ){ A.x - B.x , A.y - B.y } ;
}
cpx operator * ( const cpx &A , const cpx &B ){
    return ( cpx ){ A.x*B.x - A.y*B.y , A.x*B.y + A.y*B.x } ;
}
cpx operator / ( const cpx &A , const double &p ){
    return ( cpx ){ A.x / p , A.y / p } ;
}
 
void FFT( cpx *a , int fix ){
    for( int i = 1 ; i < N ; i ++ )
        if( expos[i] > i ) swap( a[i] , a[ expos[i] ] ) ;
    for( int i = 2 ; i <= N ; i <<= 1 ){
        cpx id = ( cpx ){ cos( 2 * PI * fix / i ) , sin( 2 * PI * fix / i ) } ;
        for( int j = 0 ; j < N ; j += i ){
            cpx W = ( cpx ){ 1 , 0 } ;
            for( int k = j ; k < j + i/2 ; k ++ ){
                cpx u = a[k] , v = a[k+i/2] * W ;
                a[k] = u + v ;
                a[k+i/2] = u - v ;
                W = W * id ;
            }
        }
    }
    if( fix == -1 )
    for( int i = 0 ; i < N ; i ++ )
        a[i] = a[i] / N ;
}
 
void manacher(){
    for( int i = 1 , id = 0 , mx = 0 ; i < Mlens ; i ++ ){
        if( mx > i ) pal[i] = min( mx - i , pal[ 2*id-i ] ) ;
        else         pal[i] = 1 ;
        while( Mss[ i+pal[i] ] == Mss[ i-pal[i] ] ) pal[i] ++ ;
        if( i + pal[i] > mx ){
            mx = i + pal[i] ;
            id = i ;
        }
    }
}
 
long long s_pow( long long x , int b ){
    long long rt = 1 ;
    while( b ){
        if( b&1 ) rt = rt * x %mmod ;
        x = x * x %mmod ; b >>= 1 ;
        //printf( "%d\n" , b ) ;
    } return rt ;
}
 
void init(){
    Mss[0] = '+' , Mss[1] = '#' ; Mlens = 2 ;
    for( int i = 0 ; i < lens ; i ++ )
        Mss[Mlens++] = ss[i] , Mss[Mlens++] = '#' ;
     
    for( N = 1 ; N < 2 * lens - 2 ; N <<= 1 ) ;
    for( int i = 1 , x = 0 ; i < N ; i ++ ){
        int tmp = N >> 1 ;
        while( tmp&x ) x ^= tmp , tmp >>= 1 ;
        x ^= tmp ; expos[i] = x ;
    }
}
 
void solve(){
    //puts( "Entre solve" ) ;
    for( int i = 0 ; i < lens ; i ++ ){
        a[i] = ( cpx ){ ss[i] == 'a' ? 1.0 : 0 , 0 } ;
        b[i] = ( cpx ){ ss[i] == 'b' ? 1.0 : 0 , 0 } ;
    }
    FFT( a , 1 ) ; FFT( b , 1 ) ;
    //puts( "DFT end" ) ;
    for( int i = 0 ; i < N ; i ++ )
        a[i] = a[i] * a[i] + b[i] * b[i] ;
    FFT( a , -1 ) ;
    //puts( "IDFT end" ) ;
    for( int i = 2 * lens - 2 ; i >= 0 ; i -- )
        Index[i] = ( ( a[i].x + 0.5 ) + 1 ) / 2 /*, printf( "Index[%d] = %d\n" , i , Index[i] )*/ ;
    manacher() ;
    //puts( "Manacher end" ) ;
 
    long long ans = 0 ;
    for( int i = 2 ; i < Mlens ; i ++ )
        ans = ( ans + ( s_pow( 2 , Index[i-2] ) - 1 ) - pal[i]/2 )%mmod ;
    printf( "%lld" , ans ) ;
     
}
 
int main(){
    scanf( "%s" , ss ) ; lens = strlen( ss ) ;
    init() ; solve() ;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值