题目:点击打开链接
题意:给n个点,每个点有一个颜色值c[i],给n-1条边,保证能够连成一个树。定义树上每两个点间的距离为这两点间的颜色种类数,求整棵树的所有路径长度和,路径总数为n*(n-1)/2
思路:要求总路径长度和,即求每条路径上颜色的种数的和,可以通过求每种颜色不在某一条路径上的和=ans,最后用路径总数*颜色总数-ans得到最终答案。具体点就是用vector容器存边,对于树上的一个点u,dfs找他的子树中颜色等于c[u]的最高的点,取出u,v的子树大小,相减得到一个不含颜色c[u]的连通块的大小,连通块内每条路径均不包含c[u],便计算累加到ans里。最后对整棵树补充一下所有颜色剩下的连通块。(这里有点想不明白是少了哪些了QAQ) 详见以下代码。
代码:
#include <bits/stdc++.h>
using namespace std;
const int maxn=2e5+10;
typedef long long ll;
ll ans,sz[maxn],sum[maxn],c[maxn],vis[maxn];
vector <int> tree[maxn];
ll dfs(int u,int pa)
{
sz[u]=1;
ll allson=0;
int cnt=tree[u].size();
for(int i=0;i<cnt;i++)
{
int v=tree[u][i];
if(v==pa) continue;
ll last=sum[c[u]];
sz[u]+=dfs(v,u);
ll add=sum[c[u]]-last;
ans+=(sz[v]-add)*(sz[v]-add-1)/2;
allson+=sz[v]-add;
}
sum[c[u]]+=allson+1;
return sz[u];
}
int main()
{
int cas=1;
ll n;
while(~scanf("%lld",&n))
{
ll cnt=0;
memset(vis,0,sizeof(vis));
memset(sum,0,sizeof(sum));
for(int i=1;i<=n;i++)
{
scanf("%lld",&c[i]);
if(vis[c[i]]==0)
{
cnt++;
vis[c[i]]=1;
}
tree[i].clear();
}
for(int i=1;i<=n-1;i++)
{
int u,v;
scanf("%d%d",&u,&v);
tree[u].push_back(v);
tree[v].push_back(u);
}
printf("Case #%d: ",cas++);
if(cnt==1)
{
ans=(n*(n-1))/2;
printf("%lld\n",ans);
}
else
{
ans=0;
dfs(1,-1);
for(int i=1;i<=n;i++)
if(vis[i]==1)
ans+=((n-sum[i])*(n-sum[i]-1))/2;
printf("%lld\n",((n*(n-1))/2)*cnt-ans);
}
}
return 0;
}