题意是给你2组字符串,前者作为前缀,后者作为后缀,然后问你拼起来的单词有几个,相同的不算。
此题有2种做法,第一种是算出所有的以字母k为首部的后缀个数a[i],然后dfs遍历前缀的字典树,如果该位置下没有某个字母k,那么就直接加上a[k],如果有那么直接往下dfs,还有如果有但是是结尾了,那么就要加上1。
举个例子,比如一个点是d,那么应该在这个位置加上除了a[d]以外的其他a[i],那么如果后面没有重复的情况那么a[d]会不会漏加?不会,因为以d为首部的后缀个数可以分解为好几个以d后面一个字母为首部的后缀之和,所以dfs到d后面一个的时候,就会把这些全部加上了。
第二种做法很巧妙,我是看了别人的代码才懂了,算出第一组所有的前缀的个数S1,算出第二组所有后缀个数S2,同时还要算第一组以字母k为末尾的前缀个数X1[k],第二组以字母k为首部的后缀个数X2[k],答案就是s1*s2-(0-25)sigma(X1[ i ] * X2[ i ])。为何可以怎么算,首先把S1×S2肯定会有很多重复,举个栗子,比如前缀abcdk,后缀kefgh,那么这2个组合出现重复的情况是组合成abcdkefgh,这种情况可以出现2次,所以只要前缀后缀出现一个相同的字母就会重复一次,前缀出现a次,后缀出现b次那么就会重复a*b次,所以就有了上面的公式。
AC代码:
#include<cstdio>
#include<ctype.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
#include<cstdlib>
#include<stack>
#include<cmath>
#include<queue>
#include<set>
#include<ctime>
#include<string.h>
#include<string>
using namespace std;
#define ll long long
#define eps 1e-8
#define NMAX 100000
template<class T>
inline void scan_d(T &ret)
{
char c;
ret=0;
while((c=getchar())<'0'||c>'9');
while(c>='0'&&c<='9') ret=ret*10+(c-'0'),c=getchar();
}
int suf[26],mark[26];
ll ans;
struct Trie
{
int ch[NMAX+5][26];
int sz;
Trie(){sz = 1; memset(ch[0],0,sizeof(ch[0]));}
void init(){sz = 1; memset(ch[0],0,sizeof(ch[0]));}
int idx(char c){return c-'a';}
void insert(char *s,int flag)
{
int u = 0, n = strlen(s);
for(int i = 0; i < n; i++)
{
int c = idx(s[i]);
if(!ch[u][c])
{
if(flag == 1) suf[c]++;
memset(ch[sz],0,sizeof(ch[sz]));
ch[u][c] = sz++;
}
u = ch[u][c];
}
}
}prefix,suffix;
void solve(int x)
{
if(x >= prefix.sz) return;
for(int i = 0; i < 26; i++)
if(!prefix.ch[x][i])
{
// if(suf[i]!=0) cout<<x<<" "<<i<<endl;
if(x != 0) ans += suf[i];
}
else
{
if(mark[i]&& x!=0) ans++;
solve(prefix.ch[x][i]);
}
// cout<<ans<<" "<<x<<endl;
}
int main()
{
#ifdef GLQ
freopen("input.txt","r",stdin);
// freopen("o2.txt","w",stdout);
#endif // GLQ
int p,s,i;
while(~scanf("%d%d",&p,&s)&&p+s)
{
memset(suf,0,sizeof(suf));
memset(mark,0,sizeof(mark));
prefix.init();
suffix.init();
char t1[1005],t2[1005];
for(i = 0; i < p; i++)
{
scanf("%s",t1);
prefix.insert(t1,0);
}
for(i = 0; i < s; i++)
{
scanf("%s",t1);
int len = strlen(t1);
for(int j = 0; j < len ;j++)
t2[j] = t1[len-j-1];
mark[t1[len-1]-'a'] = 1;
t2[len] = '\0';
// cout<<t2<<endl;
suffix.insert(t2,1);
}
ans = 0;
solve(0);
printf("%lld\n",ans);
}
return 0;
}