题目描述
秀秀有一棵带n个顶点的树T,每个节点有一个点权ai。
有一天,她想拥有两棵树,于是她从T中删去了一条边。
第二天,她认为三棵树或许会更好一些。因此,她又从她拥有的某一棵树中删去了一条边。
如此往复,每一天秀秀都会删去一条尚未被删去的边,直到她得到由n棵只有一个点的树构成的森林。
秀秀定义一条简单路径(节点不重复出现的路径)的权值为路径上所有点的权值之和,一棵树的直径为树上权值最大的简单路径。秀秀认为树最重要的特征就是它的直径。所以她想请你算出任一时刻她拥有的所有树的直径的乘积。因为这个数可能很大,你只需输出这个数对1e9+7取模之后的结果即可。
题解
有一个性质,两棵树合并后的直径的端点还是在原来的四个点上
证明:不在原来所以肯定存在跨树的一条路径,然而从那个连接点开始在那个子树跑DFS到最远点那个点一定还是原来直径上的点
代码
#include <bits/stdc++.h>
#define maxn 100005
#define INF 0x3f3f3f3f
#define LL long long
#define mod 1000000007
#define re register
using namespace std;
int read(){
int res,f=1; char c;
while(!isdigit(c=getchar())) if(c=='-') f=-1; res=(c^48);
while(isdigit(c=getchar())) res=(res<<3)+(res<<1)+(c^48);
return res*f;
}
struct EDGE{
int u,v,nxt;
}e[maxn<<1];
struct NODE{
int s,t,d;
bool operator < (const NODE&rhs)const{return d<rhs.d;}
}D[maxn];
int n,cnt=1,head[maxn],flag[maxn<<1];
void add(int u,int v){
e[++cnt]=(EDGE){u,v,head[u]};
head[u]=cnt;
}
LL Pow(LL x,int k){
LL res=1;
while(k){
if(k&1) res=res*x%mod;
x=x*x%mod; k>>=1;
}
return res;
}
int pre[maxn],fa[maxn],size[maxn],sum[maxn],son[maxn],top[maxn],dep[maxn],del[maxn];
int find(int x){return x==pre[x]?x:pre[x]=find(pre[x]);}
LL ans[maxn],w[maxn];
void DFS1(int u,int far){
size[u]=1; dep[u]=dep[far]+1; fa[u]=far;
for(re int i=head[u];~i;i=e[i].nxt){
int v=e[i].v; if(v==far) continue;
sum[v]=sum[u]+w[v];
DFS1(v,u);
size[u]+=size[v];
if(size[son[u]]<size[v]) son[u]=v;
}
}
void DFS2(int u,int far,int tp){
top[u]=tp; if(son[u]) DFS2(son[u],u,tp);
for(re int i=head[u];~i;i=e[i].nxt){
int v=e[i].v; if(v==far || v==son[u]) continue;
DFS2(v,u,v);
}
}
int DIS(int x,int y){
int X=x,Y=y;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
int lca=dep[x]<dep[y]?x:y;
return sum[X]+sum[Y]-2*sum[lca]+w[lca];
}
int main(){
memset(head,-1,sizeof head);
n=read();
ans[n]=1;
for(re int i=1;i<=n;i++){
ans[n]=ans[n]*(w[i]=read())%mod;
D[i].s=D[i].t=i;
D[i].d=w[i];
pre[i]=i;
}
for(re int i=1,u,v;i<n;i++){
u=read(); v=read();
add(u,v); add(v,u);
}
sum[1]=w[1];
DFS1(1,1);
DFS2(1,1,1);
for(re int i=1;i<n;i++) scanf("%d",&del[i]);
for(re int i=n-1;i;i--){
int x=e[del[i]<<1].u,y=e[del[i]<<1].v,fx=find(x),fy=find(y);
ans[i]=ans[i+1]*Pow(D[fx].d,mod-2)%mod*Pow(D[fy].d,mod-2)%mod;
int s1=D[fx].s,t1=D[fx].t,s2=D[fy].s,t2=D[fy].t;
D[fx]=max(D[fx],D[fy]);
D[fx]=max(D[fx],(NODE){s1,s2,DIS(s1,s2)});
D[fx]=max(D[fx],(NODE){s1,t2,DIS(s1,t2)});
D[fx]=max(D[fx],(NODE){s2,t1,DIS(s2,t1)});
D[fx]=max(D[fx],(NODE){t1,t2,DIS(t1,t2)});
pre[fy]=fx;
ans[i]=ans[i]*D[fx].d%mod;
}
for(int i=1;i<=n;i++){
printf("%lld\n",ans[i]);
}
return 0;
}