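// Approach: centroid decomposition — removing a centroid splits the tree into a forest of subtrees, each with at most half the vertices, and we recurse on each. Every vertex therefore takes part in O(log n) levels, giving O(n log n) vertex visits in the worst case (each visit also costs O(k) base-3 digit work and an O(log n) std::map operation).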
// Accepted code:
#include<cstdio>
#include<cstring>
#include<vector>
#include<map>
// Global state shared by all functions below (names are used unqualified).
// A file-wide `using namespace std;` is deliberately NOT used: combined with
// std::next (C++11, reachable through <vector>/<map>) it makes every bare use
// of the global array `next` ambiguous, which GCC/Clang reject.
typedef long long LL;
#define pb push_back            // shorthand for push_back
const int NS=500010;            // capacity of the per-vertex arrays (vertices are 1..n)

int n,k,top;                    // vertex count, number of divisors in pri, next free arc slot
LL pri[50];                     // the k divisors read from input (assumes k <= 50)
LL val[NS];                     // per-vertex base-3 exponent signature (filled by trans)
LL sta[NS];                     // get_son: signature of the path from its start vertex to v
LL ans;                         // number of counted paths for the current test case

// Adjacency list: head[u] = first arc of u (-1 = none), to[e] = arc target,
// next[e] = following arc of the same vertex.
int head[NS],to[NS<<1],next[NS<<1];

int vis[NS],vct=0;              // BFS stamps: v reached in the current pass iff vis[v]==vct
int vec[NS];                    // 1 once the vertex has been removed as a centroid
int que[NS];                    // BFS queue, used 1-based
int nod[NS],far[NS];            // fheart: BFS-subtree sizes and BFS parents

std::vector<LL> wt;             // get_son: signatures of all paths from its start vertex
std::map<LL,int> shi;           // dfs: path signature -> count, for paths through the centroid
void add(int u,int v)
{
next[top]=head[u];
to[top]=v;
head[u]=top++;
next[top]=head[v];
to[top]=u;
head[v]=top++;
}
// Replace x by its base-3 signature: digit i (weight 3^i) is the exponent of
// pri[i] in x, reduced modulo 3. Factors not among pri[0..k-1] are ignored.
// NOTE(review): assumes x >= 1 and every pri[i] >= 2; x == 0 or pri[i] == 1
// would never leave the inner loop, and pri[i] == 0 divides by zero.
void trans(LL &x)
{
    LL remaining = x;
    LL signature = 0;
    LL place = 1;
    int idx = 0;
    while (idx < k)
    {
        const LL p = pri[idx];
        int exponent = 0;
        while (remaining % p == 0)
        {
            remaining /= p;
            ++exponent;
        }
        signature += place * (exponent % 3);
        place *= 3;
        ++idx;
    }
    x = signature;
}
// Digit-wise addition modulo 3 of two k-digit base-3 signatures.
// This is the signature of the product of the two underlying values.
LL combin(LL x,LL y)
{
    LL result = 0;
    LL place = 1;
    for (int digit = 0; digit < k; ++digit, place *= 3)
    {
        result += place * ((x + y) % 3);
        x /= 3;
        y /= 3;
    }
    return result;
}
// Digit-wise negation modulo 3 of a k-digit base-3 signature, so that
// combin(x, rev(x)) == 0 for a valid signature x.
LL rev(LL x)
{
    LL result = 0;
    LL place = 1;
    for (int digit = 0; digit < k; ++digit, place *= 3)
    {
        const LL d = x % 3;
        result += place * ((3 - d) % 3);
        x /= 3;
    }
    return result;
}
int fheart(int rt)
{
int tot=0;
vis[rt]=++vct;
que[++tot]=rt;
for (int i=1;i<=tot;i++)
{
int cur=que[i];
for (int j=head[cur];j!=-1;j=next[j])
{
int son=to[j];
if (!vec[son] && vis[son]<vct)
{
far[son]=cur;
vis[son]=vct;
que[++tot]=son;
}
}
}
if (tot==1)
{
ans+=(val[rt]==0);
return -1;
}
int ban=tot/2;
for (int i=tot;i>0;i--)
{
int cur=que[i],flag=1;
nod[cur]=1;
for (int j=head[cur];j!=-1;j=next[j])
{
int son=to[j];
if (far[son]==cur && !vec[son])
{
nod[cur]+=nod[son];
if (nod[son]>ban)
flag=0;
}
}
if (flag && tot-nod[cur]<=ban)
return cur;
}
}
void get_son(int rt)
{
wt.clear();
int tot=0;
vis[rt]=++vct;
sta[rt]=val[rt];
que[++tot]=rt;
wt.pb(sta[rt]);
for (int i=1;i<=tot;i++)
{
int cur=que[i];
for (int j=head[cur];j!=-1;j=next[j])
{
int son=to[j];
if (!vec[son] && vis[son]<vct)
{
vis[son]=vct;
sta[son]=combin(sta[cur],val[son]);
wt.pb(sta[son]);
que[++tot]=son;
}
}
}
}
void dfs(int rt)
{
int ht=fheart(rt);
// printf("rt=%d ht=%d\n",rt,ht);
if (ht<0) return ;
vec[ht]=1;
shi.clear();
ans+=(val[ht]==0);
shi[val[ht]]++;
for (int i=head[ht];i!=-1;i=next[i])
{
int son=to[i];
if (!vec[son])
{
get_son(son);
for (int j=0;j<wt.size();j++)
{
LL tmp=rev(wt[j]);
if (shi.find(tmp)!=shi.end())
ans+=shi[tmp];
}
for (int j=0;j<wt.size();j++)
shi[combin(wt[j],val[ht])]++;
}
}
for (int i=head[ht];i!=-1;i=next[i])
if (!vec[to[i]])
dfs(to[i]);
}
// Reads test cases until n can no longer be read. Each case has n, k, the k
// divisors pri[0..k-1], n vertex values (converted in place by trans), and
// n-1 edges. It prints the number of counted paths, single vertices included.
// %lld is the standard conversion for long long; the previous %I64d is a
// Microsoft extension, not standard C/C++.
// NOTE(review): assumes n+2 <= NS, k <= 50 (size of pri) and 3^k fits in
// long long (k <= 39); none of these are checked.
int main()
{
    while (std::scanf("%d",&n)==1)   // ~scanf would loop forever on a 0 return
    {
        top=vct=0;
        std::memset(vec,0,sizeof(int)*(n+2));
        std::memset(vis,0,sizeof(int)*(n+2));
        std::memset(head,-1,sizeof(int)*(n+2));
        std::scanf("%d",&k);
        for (int i=0;i<k;i++)
            std::scanf("%lld",&pri[i]);
        for (int i=1;i<=n;i++)
        {
            std::scanf("%lld",&val[i]);
            trans(val[i]);
        }
        for (int i=1;i<n;i++)
        {
            int u,v;
            std::scanf("%d %d",&u,&v);
            add(u,v);
        }
        ans=0;
        if (n>0)        // with no vertices, dfs(1) would read a stale val[1]
            dfs(1);
        std::printf("%lld\n",ans);
    }
    return 0;
}