poj 3376 Finding Palindromes(扩展kmp+trie)
题意:给出n个字符串,问这n个字符串两两链接(一共有n^2中连接方法),组成的所有的字符串中,有多少个回文串。
好题!!!
解题思路:对于两个串连接是否是回文串,我们应该怎样去判断了?假如我们把其中一个翻转,若此时,短的那个串是长的那个的前缀,而长的那个串后面剩余的后缀恰好是个回文串,那这两个串连起来就是个回文串了。比如abc和abacba连接,我们把后一个串翻转,得到abcaba,abc为其前缀,而aba是个回文串,那么连起来就是个回文串了(这个规律找到就好办了)。那我们就把所有的串插入到trie中,然后再用所有的反串去匹配就行了。匹配的过程中,走到任意一个节点,而这个节点有可能是若干个串的结尾,那么此时我们就要判反串匹配位置下面剩余的部分是否回文。如果是的,ans就加上以这个节点为结尾的原串的个数(这个插入的时候就可以统计进去了)。如果走完了,还没走到叶子节点,那么就要看走到的节点下的子树(其实是以前面走过的路径为前缀的字符串剩下的一些后缀)有多少是回文的了(这个先预处理所有的串的后缀有哪些是回文的,然后在插入的时候统计到节点上)。剩下来一个问题就是如何在线性的时间内(或许o(nlogn)也可以吧,但我们有线性的算法,岂不更好?),这里我只是说,用扩展kmp能很合适。具体如何实现,留个小思考给大家(很简单的啦)。。
- #include<stdio.h>
- #include<algorithm>
- #include<string.h>
- #define ll __int64
- #include<vector>
- using namespace std ;
- const int maxn = 2222222 ;
- char vec[maxn] ;
- int g[maxn] , nxt[maxn] ;
- bool li[maxn] ;
- int ok[maxn] , p[maxn] ;
- void get_p (const char *T){
- int len=strlen(T),a=0;
- int i , k ;
- p[0]=len;
- while(a<len-1 && T[a]==T[a+1]) a++;
- p[1]=a;
- a=1;
- for( k=2;k<len;k++){
- int fuck=a+p[a]-1,L=p[k-a];
- if( (k-1)+L >= fuck){
- int j = (fuck-k+1)>0 ? (fuck-k+1) : 0;
- while(k+j<len && T[k+j]==T[j]) j++;
- p[k]=j;
- a=k;
- }
- else p[k]=L;
- }
- }
- void match ( char *s , char *s1 ) {
- int len = strlen ( s ) , len1 = strlen ( s1 ) ;
- int i = 0 , k , j = 0 , a ;
- while ( i < len && j < len1 && s[i] == s1[j] ) i ++ , j ++ ;
- ok[0] = j ;
- a = 0 ;
- for ( k = 1 ; k < len ; k ++ ) {
- int fuck = a + ok[a] - 1 , l = p[k-a] ;
- if ( k + l - 1 >= fuck ) {
- int j = ( fuck - k + 1 ) > 0 ? ( fuck - k + 1 ) : 0 ;
- while ( k + j < len && j < len1 && s[k+j] == s1[j] ) j ++ ;
- ok[k] = j ;
- a = k ;
- }
- else ok[k] = l ;
- }
- }
- int tot = 0 , c[26][maxn] , cnt[maxn] , val[maxn] ;
- int new_node () {
- int i ;
- for ( i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;
- cnt[tot] = val[tot] = 0 ;
- return tot ++ ;
- }
- void insert ( char *s ) {
- int len = strlen ( s ) , i , now = 0 ;
- for ( i = 0 ; i < len ; i ++ ) {
- int k = s[i] - 'a' ;
- if ( !c[k][now] ) c[k][now] = new_node () ;
- now = c[k][now] ;
- if ( i + 1 < len && ok[i+1] == len - i - 1 ) {
- cnt[now] ++ ;
- }
- }
- cnt[now] ++ ;
- val[now] ++ ;
- }
- ll ans = 0 ;
- void cal ( int len ) {
- int j , i , now = 0 ;
- li[len+1] = 1 ;
- // printf ( "len = %d\n" , len ) ;
- // for ( i = 1 ; i <= len ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;
- for ( j = 1 ; j <= len ; j ++ ) {
- // printf ( "j = %d , ans = %I64d\n" , j , ans ) ;
- if ( li[j] ) now = 0 ;
- int k = vec[j] - 'a' ;
- if ( !c[k][now] ) {
- now = 0 ;
- // printf ( "nxt[%d] = %d\n" , j , nxt[j] ) ;
- j = nxt[j] - 1 ;
- continue ;
- }
- now = c[k][now] ;
- if ( !li[j+1] && g[j+1] ) ans += (ll) val[now] ;
- // printf ( "j = %d , now = %d\n" , j , now ) ;
- if ( li[j+1] ) {
- // if ( j == 10 ) printf ( "cnt[%d] = %d\n" , now , cnt[now] ) ;
- ans += (ll) cnt[now] ;
- now = 0 ;
- }
- }
- }
- char s1[maxn] , s[maxn] ;
- int main () {
- int n , i , j , k ;
- while ( scanf ( "%d" , &n ) != EOF ) {
- tot = 0 ;
- new_node () ;
- int t = 0 ;
- for ( i = 1 ; i <= n ; i ++ ) {
- scanf ( "%d%s" , &j , s ) ;
- strcpy ( s1 , s ) ;
- int len = strlen ( s ) ;
- reverse ( s1 , s1 + len ) ;
- get_p ( s1 ) ;
- match ( s , s1 ) ;
- insert ( s ) ;
- get_p ( s ) ;
- match ( s1 , s ) ;
- li[t+1] = 1 ;
- for ( j = 0 ; j < len ; j ++ ) {
- if ( ok[j] == len - j ) g[++t] = 1 ;
- else g[++t] = 0 ;
- vec[t] = s1[j] ;
- if ( j ) li[t] = 0 ;
- }
- }
- // for ( i = 1 ; i < t ; i ++ ) printf ( "%d " , cnt[i] ) ; puts ( "" ) ;
- int last = t + 1 ;
- // for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , li[i] ) ; puts ( "" ) ;
- for ( i = t ; i >= 1 ; i -- ) {
- nxt[i] = last ;
- if ( li[i] ) last = i ;
- }
- // for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;
- ans = 0 ;
- cal ( t ) ;
- printf ( "%I64d\n" , ans ) ;
- }
- }