这道题首先要读明白题,就是说叶子节点不超过20个,那么我们就可以以每一个叶子节点为根建一个广义后缀自动机,这样就一定能表示出来所有的子串,然后统计一下答案就可以啦。
(广义后缀自动机就是把好多串放到一块,每次都从root开始建后缀自动机,但由于这道题是一棵树,所以我们只需先把节点开出来,在dfs的过程中在dfs的出发点之后新插入一个字符)
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<ctime>
#include<cstring>
#include<string>
#include<iomanip>
#include<iostream>
#include<algorithm>
using namespace std;
int fa[1000000];
struct bian
{
int l,r;
}a[1000000];
struct sam
{
sam *parent,*son[10];
int max_len;
sam(int _=0):max_len(_),parent(0x0){memset(son,0x0,sizeof(son));}
}*root=new sam();
long long ans=0;
sam *my_insert(sam *p,int x)
{
if(p->son[x] && p->son[x]->max_len==p->max_len+1) return p->son[x];
sam* np=new sam(p->max_len+1);
while(p && !p->son[x])
{
p->son[x]=np;
p=p->parent;
}
if(!p) np->parent=root;
else
{
sam *q=p->son[x];
if(q->max_len==p->max_len+1) np->parent=q;
else
{
sam *nq=new sam(p->max_len+1);
ans-=q->max_len-q->parent->max_len;
nq->parent=q->parent;
memcpy(nq->son,q->son,sizeof(nq->son));
q->parent=nq;np->parent=nq;
ans+=nq->max_len-nq->parent->max_len;
ans+=q->max_len-q->parent->max_len;
while(p && p->son[x]==q)
{
p->son[x]=nq;
p=p->parent;
}
}
}
ans+=np->max_len-np->parent->max_len;
//cout<<np->max_len<<" "<<np->parent->max_len<<endl;
return np;
}
int fir[1000000];
int nex[1000000];
int tot=0;
sam *point[1000000];
int v[1000000];
int rudu[1000000];
void add_edge(int l,int r)
{
a[++tot].l=l;
a[tot].r=r;
nex[tot]=fir[l];
fir[l]=tot;
rudu[l]++;
rudu[r]++;
}
void dfs(int u,int fro)
{
point[u]=my_insert(point[fro],v[u]);
for(int o=fir[u];o!=0;o=nex[o]) if(a[o].r!=fro) dfs(a[o].r,u);
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&v[i]);
for(int i=1;i<n;i++)
{
int l,r;
scanf("%d%d",&l,&r);
add_edge(l,r);
add_edge(r,l);
}
point[0]=root;
for(int i=1;i<=n;i++)
{
if(rudu[i]==2) dfs(i,0);
}
cout<<ans;
return 0;
}