### 题目
幽香是全幻想乡里最受人欢迎的萌妹子,这天,是幽香的2600岁生日,无数幽香的粉丝到了幽香家门前的太阳花田上来为幽香庆祝生日。
粉丝们非常热情,自发组织表演了一系列节目给幽香看。幽香当然也非常高兴啦。
这时幽香发现了一件非常有趣的事情,太阳花田有n块空地。在过去,幽香为了方便,在这n块空地之间修建了n-1条边将它们连通起来。也就是说,这n块空地形成了一个树的结构。
有n个粉丝们来到了太阳花田上。为了表达对幽香生日的祝贺,他们选择了c中颜色的衣服,每种颜色恰好可以用一个0到c-1之间的整数来表示。并且每个人都站在一个空地上,每个空地上也只有一个人。这样整个太阳花田就花花绿绿了。幽香看到了,感觉也非常开心。
粉丝们策划的一个节目是这样的,选中两个粉丝A和B(A和B可以相同),然后A所在的空地到B所在的空地的路径上的粉丝依次跳起来(包括端点),幽香就能看到一个长度为A到B之间路径上的所有粉丝的数目(包括A和B)的颜色序列。一开始大家打算让人一两个粉丝(注意:A,B和B,A是不同的,他们形成的序列刚好相反,比如红绿蓝和蓝绿红)都来一次,但是有人指出这样可能会出现一些一模一样的颜色序列,会导致审美疲劳。
于是他们想要问题,在这个树上,一共有多少可能的不同的颜色序列(子串)幽香可以看到呢?
太阳花田的结构比较特殊,只与一个空地相邻的空地数量不超过20个。
输入
第一行两个正整数n,c。表示空地数量和颜色数量。
第二行有n个0到c-1之间,由空格隔开的整数,依次表示第i块空地上的粉丝的衣服颜色。(这里我们按照节点标号从小到大的顺序依次给出每块空地上粉丝的衣服颜色)。
接下来n-1行,每行两个正整数u,v,表示有一条连接空地u和空地v的边。
输出
一行,输出一个整数,表示答案。
样例输入
7 3
0 2 1 2 1 0 0
1 2
3 4
3 5
4 6
5 7
2 5
样例输出
30
分析
这道题目很明显是一道广义后缀自动机的题目,由于叶子数 ≤ 20,所以可以以每个叶子节点为根建一个trie,在上面建广义后缀自动机,之后统计上面有几个本质不同的子串即可,即 ∑ni=1max[i]−min[i]+1=max[i]−max[fa[i]]
完整代码
#include<bits/stdc++.h>
#define maxn 100010
#define maxm 200010
#define maxt 2000010
//#define DEBUG
using namespace std;
int n,m,num=0;
int col[maxn],pre[maxm],now[maxm],deg[maxm],to[maxm];
struct SAM
{
int root,tot,last;
int son[maxt][15],fa[maxt],maxl[maxt];
void init() { root=1,tot=1,last=1; }
void add(int x) //加点方案基本相同
{
int p=last,q=son[p][x];
if(q)
{
if(maxl[q]==maxl[p]+1) last=q;
else
{
int nq=++tot;
maxl[nq]=maxl[p]+1;
memcpy(son[nq],son[q],sizeof(son[q]));
fa[nq]=fa[q];
fa[q]=nq;
for( ; son[p][x]==q ; p=fa[p] ) son[p][x]=nq;
last=nq;
}
}
else
{
int np=++tot;
maxl[np]=maxl[p]+1;
for( ; p&&!son[p][x] ; p=fa[p] ) son[p][x]=np;
if(!p) fa[np]=root;
else
{
int q=son[p][x];
if(maxl[q]==maxl[p]+1) fa[np]=q;
else
{
int nq=++tot;
maxl[nq]=maxl[p]+1;
memcpy(son[nq],son[q],sizeof(son[q]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for( ; son[p][x]==q ; p=fa[p] ) son[p][x]=nq;
}
}
last=np;
}
}
void print()
{
long long ans=0;
for(int i=1;i<=tot;++i)
ans+=maxl[i]-maxl[fa[i]];
printf("%lld\n",ans);
}
} sam ;
inline int read()
{
char ch;
int read=0,sign=1;
do
ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-');
if(ch=='-') sign=-1,ch=getchar();
while(ch>='0' && ch<='9')
{
read=read*10+ch-'0';
ch=getchar();
}
return sign*read;
}
inline void add(int u,int v)
{
pre[++num]=now[u];
now[u]=num;
to[num]=v;
deg[v]++;
}
void dfs(int from,int bef)
{
#ifdef DEBUG
printf("dfs(%d,%d)\n",from,bef);
#endif
sam.add(col[from]);
int t=sam.last;
for(int y=now[from];y;y=pre[y])
if(to[y]!=bef) dfs(to[y],from),sam.last=t;
}
int main()
{
n=read(),m=read();
for(int i=1;i<=n;++i) col[i]=read();
for(int i=1,u,v;i<n;++i) u=read(),v=read(),add(u,v),add(v,u);
sam.init();
for(int i=1;i<=n;++i)
if(deg[i]==1) sam.last=sam.root,dfs(i,0);
#ifdef DEBUG
cerr<<"tot="<<sam.tot<<endl;
#endif
sam.print();
return 0;
}