链接:https://www.nowcoder.com/acm/contest/91/B
来源:牛客网
题目描述
在埃森哲,员工培训是最看重的内容,最近一年,我们投入了 9.41 亿美元用于员工培训和职业发展。截至 2018 财年末,我们会在全球范围内设立 100 所互联课堂,将互动科技与创新内容有机结合起来。按岗培训,按需定制,随时随地,本土化,区域化,虚拟化的培训会让你快速取得成长。小埃希望能通过培训学习更多ACM 相关的知识,他在培训中碰到了这样一个问题,
给定一棵n个节点的树,并且根节点的编号为p,第i个节点有属性值vali, 定义F(i): 在以i为根的子树中,属性值是vali的合约数的节点个数。y 是 x 的合约数是指 y 是合数且 y 是 x 的约数。小埃想知道∑i*F(i) ,(i=0~n)对1000000007取模后的结果.
输入描述:
输入测试组数T,每组数据,输入n+1行整数,第一行为n和p,1<=n<=20000, 1<=p<=n, 接下来n-1行,每行两个整数u和v,表示u和v之间有一条边。第n+1行输入n个整数val1, val2,…, valn,其中1<=vali<=10000,1<=i<=n.
输出描述:
对于每组数据,输出一行,包含1个整数, 表示对1000000007取模后的结果
示例1
输入
2
5 4
5 3
2 5
4 2
1 3
10 4 3 10 5
3 3
1 3
2 1
1 10 1
输出
11
2
备注:
n>=10000的有20组测试数据
思路:对[1,10000]每个数字的合约数作预处理,然后用dfs序求解
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<iostream>
#include<string.h>
#include<vector>
#include<set>
using namespace std;
typedef long long ll;
const int mod = 1e9+7;
const int maxn = 2e4+10;
vector<int>g[maxn];
vector<int>vt[maxn];
bool vis[maxn];
int w[maxn];
int n,p;
ll cnt[maxn];
ll ans;
void init()//预处理[1,10000]每个数的合约数
{
//vis[]=1质数,vis[]=0合数
for(int i=1;i<=10000;i++) vis[i]=1;
for(int i=2;i<=10000;i++)
{
if(!vis[i]) continue;
for(int j=i+i;j<=10000;j+=i)
vis[j]=0;
}
for(int i=2;i<=10000;i++)
if(!vis[i])
{
for(int j=i;j<=10000;j+=i)
vt[j].push_back(i);
}
}
//i*f[i]相当于f[i]个i累加,即对于结点u来说,它的子节点中每个合约数结点都加上一个u
//整个树中,一个结点i的cnt[i]等于结点x,y,z……的和(i是x,y,z的合约数)
void dfs(int u,int fa)
{
//将这个点的合约数全部加上这个点,这个点的合约数可能存在于子结点也可能是兄弟结点
for(int i=0;i<vt[w[u]].size();i++)
{
int v=vt[w[u]][i];
cnt[v]=(cnt[v]+u)%mod;
}
//搜到这个点时,这个点已经加过所有满足条件的父结点
ans=(ans+cnt[w[u]])%mod;
for(int i=0;i<g[u].size();i++)
{
if(g[u][i]!=fa)
dfs(g[u][i],u);
}
//防止搜索兄弟结点时加上这个点
for(int i=0;i<vt[w[u]].size();i++)
{
int v=vt[w[u]][i];
cnt[v]=(cnt[v]-u)%mod;
}
}
int main()
{
int T;
scanf("%d",&T);
init();
while(T--)
{
ans=0;
int n,p;
scanf("%d%d",&n,&p);
for(int i=1;i<=n;i++) g[i].clear();
for(int i=1;i<=n;i++) cnt[i]=0;
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
dfs(p,-1);
printf("%lld\n",ans);
}
return 0;
}