poj 3376 Finding Palindromes(扩展kmp+trie)
题意:给出n个字符串,问这n个字符串两两链接(一共有n^2中连接方法),组成的所有的字符串中,有多少个回文串。
好题!!!
解题思路:对于两个串连接是否是回文串,我们应该怎样去判断了?假如我们把其中一个翻转,若此时,短的那个串是长的那个的前缀,而长的那个串后面剩余的后缀恰好是个回文串,那这两个串连起来就是个回文串了。比如abc和abacba连接,我们把后一个串翻转,得到abcaba,abc为其前缀,而aba是个回文串,那么连起来就是个回文串了(这个规律找到就好办了)。那我们就把所有的串插入到trie中,然后再用所有的反串去匹配就行了。匹配的过程中,走到任意一个节点,而这个节点有可能是若干个串的结尾,那么此时我们就要判反串匹配位置下面剩余的部分是否回文。如果是的,ans就加上以这个节点为结尾的原串的个数(这个插入的时候就可以统计进去了)。如果走完了,还没走到叶子节点,那么就要看走到的节点下的子树(其实是以前面走过的路径为前缀的字符串剩下的一些后缀)有多少是回文的了(这个先预处理所有的串的后缀有哪些是回文的,然后在插入的时候统计到节点上)。剩下来一个问题就是如何在线性的时间内(或许o(nlogn)也可以吧,但我们有线性的算法,岂不更好?),这里我只是说,用扩展kmp能很合适。具体如何实现,留个小思考给大家(很简单的啦)。。
#include<stdio.h>
#include<algorithm>
#include<string.h>
#define ll __int64
#include<vector>
using namespace std ;
const int maxn = 2222222 ;
char vec[maxn] ;
int g[maxn] , nxt[maxn] ;
bool li[maxn] ;
int ok[maxn] , p[maxn] ;
void get_p (const char *T){
int len=strlen(T),a=0;
int i , k ;
p[0]=len;
while(a<len-1 && T[a]==T[a+1]) a++;
p[1]=a;
a=1;
for( k=2;k<len;k++){
int fuck=a+p[a]-1,L=p[k-a];
if( (k-1)+L >= fuck){
int j = (fuck-k+1)>0 ? (fuck-k+1) : 0;
while(k+j<len && T[k+j]==T[j]) j++;
p[k]=j;
a=k;
}
else p[k]=L;
}
}
void match ( char *s , char *s1 ) {
int len = strlen ( s ) , len1 = strlen ( s1 ) ;
int i = 0 , k , j = 0 , a ;
while ( i < len && j < len1 && s[i] == s1[j] ) i ++ , j ++ ;
ok[0] = j ;
a = 0 ;
for ( k = 1 ; k < len ; k ++ ) {
int fuck = a + ok[a] - 1 , l = p[k-a] ;
if ( k + l - 1 >= fuck ) {
int j = ( fuck - k + 1 ) > 0 ? ( fuck - k + 1 ) : 0 ;
while ( k + j < len && j < len1 && s[k+j] == s1[j] ) j ++ ;
ok[k] = j ;
a = k ;
}
else ok[k] = l ;
}
}
int tot = 0 , c[26][maxn] , cnt[maxn] , val[maxn] ;
int new_node () {
int i ;
for ( i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;
cnt[tot] = val[tot] = 0 ;
return tot ++ ;
}
void insert ( char *s ) {
int len = strlen ( s ) , i , now = 0 ;
for ( i = 0 ; i < len ; i ++ ) {
int k = s[i] - 'a' ;
if ( !c[k][now] ) c[k][now] = new_node () ;
now = c[k][now] ;
if ( i + 1 < len && ok[i+1] == len - i - 1 ) {
cnt[now] ++ ;
}
}
cnt[now] ++ ;
val[now] ++ ;
}
ll ans = 0 ;
void cal ( int len ) {
int j , i , now = 0 ;
li[len+1] = 1 ;
// printf ( "len = %d\n" , len ) ;
// for ( i = 1 ; i <= len ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;
for ( j = 1 ; j <= len ; j ++ ) {
// printf ( "j = %d , ans = %I64d\n" , j , ans ) ;
if ( li[j] ) now = 0 ;
int k = vec[j] - 'a' ;
if ( !c[k][now] ) {
now = 0 ;
// printf ( "nxt[%d] = %d\n" , j , nxt[j] ) ;
j = nxt[j] - 1 ;
continue ;
}
now = c[k][now] ;
if ( !li[j+1] && g[j+1] ) ans += (ll) val[now] ;
// printf ( "j = %d , now = %d\n" , j , now ) ;
if ( li[j+1] ) {
// if ( j == 10 ) printf ( "cnt[%d] = %d\n" , now , cnt[now] ) ;
ans += (ll) cnt[now] ;
now = 0 ;
}
}
}
char s1[maxn] , s[maxn] ;
int main () {
int n , i , j , k ;
while ( scanf ( "%d" , &n ) != EOF ) {
tot = 0 ;
new_node () ;
int t = 0 ;
for ( i = 1 ; i <= n ; i ++ ) {
scanf ( "%d%s" , &j , s ) ;
strcpy ( s1 , s ) ;
int len = strlen ( s ) ;
reverse ( s1 , s1 + len ) ;
get_p ( s1 ) ;
match ( s , s1 ) ;
insert ( s ) ;
get_p ( s ) ;
match ( s1 , s ) ;
li[t+1] = 1 ;
for ( j = 0 ; j < len ; j ++ ) {
if ( ok[j] == len - j ) g[++t] = 1 ;
else g[++t] = 0 ;
vec[t] = s1[j] ;
if ( j ) li[t] = 0 ;
}
}
// for ( i = 1 ; i < t ; i ++ ) printf ( "%d " , cnt[i] ) ; puts ( "" ) ;
int last = t + 1 ;
// for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , li[i] ) ; puts ( "" ) ;
for ( i = t ; i >= 1 ; i -- ) {
nxt[i] = last ;
if ( li[i] ) last = i ;
}
// for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;
ans = 0 ;
cal ( t ) ;
printf ( "%I64d\n" , ans ) ;
}
}