题目:BZOJ3926.
题目大意:给定一棵
n
n
n个节点的树和数字
c
c
c,每个节点都有一个点权
a
i
a_i
ai.现在要求这棵树上所有本质不同的链的数量(两条链本质不同仅当两条链分别写成字符串的形式不相同).
1
≤
n
≤
1
0
5
,
1
≤
c
≤
10
1\leq n\leq 10^5,1\leq c\leq 10
1≤n≤105,1≤c≤10,度为
1
1
1的节点数量
≤
20
\leq 20
≤20.
求本质不同的子串数量可以通过SAM来求解,但是现在问题到了树上怎么办.
考虑把整棵树看成一个Trie树,建立广义SAM.但是这样子的链必定是从一个节点到一个节点的祖先,要求两个端点LCA不为端点的链怎么办?
看到度为 1 1 1的节点数量 ≤ 20 \leq 20 ≤20,想到一条链必定会在以某个度为 1 1 1的节点为根时变成一个点到祖先的形式,所以大力以度为 1 1 1的节点为根分别处理20棵Trie树,这样子就可以把所有Trie都塞到一个SAM里,算本质不同的子串数量就行了.
时空复杂度 O ( 20 n Σ ) O(20n\Sigma) O(20nΣ).
代码如下:
#include<bits/stdc++.h>
using namespace std;
#define Abigail inline void
typedef long long LL;
const int N=100000,C=10,L=20;
int n,c;
struct automaton{
int s[C],len,par;
}tr[N*L*2+9];
int cn;
void Build_sam(){cn=1;}
int extend(int x,int last){
int p=last;
if (tr[p].s[x]){
int q=tr[p].s[x];
if (tr[p].len+1==tr[q].len) return q;
else {
tr[++cn]=tr[q];tr[cn].len=tr[p].len+1;
tr[q].par=cn;
while (p&&tr[p].s[x]==q) tr[p].s[x]=cn,p=tr[p].par;
return cn;
}
}else{
int np=++cn;
tr[np].len=tr[p].len+1;
while (p&&!tr[p].s[x]) tr[p].s[x]=np,p=tr[p].par;
if (!p) tr[np].par=1;
else {
int q=tr[p].s[x];
if (tr[p].len+1==tr[q].len) tr[np].par=q;
else {
tr[++cn]=tr[q];tr[cn].len=tr[p].len+1;
tr[q].par=tr[np].par=cn;
while (p&&tr[p].s[x]==q) tr[p].s[x]=cn,p=tr[p].par;
}
}
return np;
}
}
struct side{
int y,next;
}e[N*2+9];
int lin[N+9],top,deg[N+9],a[N+9];
void ins(int x,int y){
e[++top].y=y;
e[top].next=lin[x];
lin[x]=top;
}
int last[N+9];
queue<int>q;
LL ans;
void bfs(int st){
for (int i=1;i<=n;++i) last[i]=-1;
q.push(st);last[st]=extend(a[st],1);
while (!q.empty()){
int t=q.front();q.pop();
for (int i=lin[t];i;i=e[i].next)
if (last[e[i].y]==-1){
last[e[i].y]=extend(a[e[i].y],last[t]);
q.push(e[i].y);
}
}
}
Abigail into(){
scanf("%d%d",&n,&c);
for (int i=1;i<=n;++i)
scanf("%d",&a[i]);
int x,y;
for (int i=1;i<n;++i){
scanf("%d%d",&x,&y);
ins(x,y);ins(y,x);
++deg[x];++deg[y];
}
}
Abigail work(){
Build_sam();
for (int i=1;i<=n;++i)
if (deg[i]==1) bfs(i);
for (int i=2;i<=cn;++i)
ans+=(LL)tr[i].len-tr[tr[i].par].len;
}
Abigail outo(){
printf("%lld\n",ans);
}
int main(){
into();
work();
outo();
return 0;
}