可将点权转化为三进制
树的分治:根据子树的节点数最小 求出树的重心 保证了复杂度尽量小
对于重心节点 的一棵子树搜一遍得到所有到那个子结点的权值积 与已经搜过的权值积 相乘 然后计算答案
#include<bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long LL;
const int INF = 1<<30;
const int MAXN = 50005;
const int mod = 100000000;
int n,p[66],k;
int vis[MAXN],siz[MAXN],mx[MAXN],mi,ans,rot;
vector<int>G[MAXN];
map<LL,int> sta;
map<LL,int>::iterator it;
LL val[MAXN];
LL gao(LL w)
{
LL ret=0,bas=1;
for(int j=0;j<k;j++)
{
int t=0;
while(w%p[j]==0){
w/=p[j];
t++;
}
ret+=(t%3)*bas;
bas*=3;
}
return ret;
}
LL Xor(LL x,LL y){
LL res=0,bas=1;
for(int i=0;i<k;i++){
res+=((x%3)+(y%3))%3*bas;
bas*=3;
x/=3;y/=3;
}
return res;
}
LL inv(LL x)
{
LL res=0,bas=1;
for(int i=0;i<k;i++)
{
res+=((3-(x%3))%3)*bas;
bas*=3;
x/=3;
}
return res;
}
void calsize(int u,int pre)
{
siz[u]=1;mx[u]=0;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v==pre||vis[v]) continue;
calsize(v,u);
siz[u]+=siz[v];
if(siz[v]>mx[u]) mx[u]=siz[v];
}
}
void getroot(int u,int pre,int father){
if(siz[father]-siz[u]>mx[u]) mx[u]=siz[father]-siz[u];
if(mx[u]<mi) mi=mx[u],rot=u;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v==pre||vis[v]) continue;
getroot(v,u,father);
}
}
void mul(int u,int pre,LL d,map<LL,int>&ds)
{
if(ds.count(d)) ds[d]++;
else ds[d]=1;
for(int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if(v==pre||vis[v]) continue;
mul(v,u,Xor(d,val[v]),ds);
}
}
int dfs(int u)
{
mi=INF;
calsize(u,u);
getroot(u,u,u);
int root=rot;
vis[root]=1;
for(int i=0;i<G[root].size();i++){
if(vis[G[root][i]]) continue;
dfs(G[root][i]);
}
sta.clear();
sta[val[root]]=1;map<LL,int>tds;
for(int i=0;i<G[root].size();i++)
{
int v=G[root][i];
if(vis[v]) continue;
tds.clear();
mul(v,root,val[v],tds);
it=tds.begin();
while(it!=tds.end())
{
LL rev=inv((*it).first);
if(sta.count(rev))
ans+=sta[rev]*(*it).second;
++it;
}
it=tds.begin();
while(it!=tds.end())
{
LL w=Xor((*it).first,val[root]);
if(sta.count(w))
sta[w]+=(*it).second;
else sta[w]=(*it).second;
++it;
}
}
vis[root]=0;
}
int main()
{
while(scanf("%d",&n)!=EOF)
{
int a,b;
ans=0;
scanf("%d",&k);
for(int i=1;i<=n;i++) G[i].clear();
for(int i=0;i<k;i++)
scanf("%d",&p[i]);
sort(p,p+k);k=unique(p,p+k)-p;
for(int i=1;i<=n;i++){
scanf("%lld",&val[i]);
val[i]=gao(val[i]);
if(val[i]==0)
ans++;
}
for(int i=0;i<n-1;i++){
scanf("%d%d",&a,&b);
G[a].push_back(b);
G[b].push_back(a);
}
memset(vis,0,sizeof(vis));
dfs(1);
cout<<ans<<endl;
}
return 0;
}
/*
*/