题意
现在有一棵树,这棵树上每个点有一个权值xxx,现在想要在一些点上增加一个权值yyy。希望任何一个点的权值xxx能够在一条路上这条路上两个端点的权值yyy都大于等于xxx。请问权值yyy总和最小是多少?
题解
先明确一个简单的定理:需要增加权值yyy的点一定是叶子节点。
证明:叶子节点作为整棵树的端点一定要增加权值yyy。如果在叶子节点之外的节点有一个点也增加了权值yyy,那么分为两种情况。一种情况是叶子节点的yyy大于等于当前节点的xxx值,这种情况下增加当前节点的yyy值没有意义;另一种情况是叶子节点的yyy小于当前节点的xxx值,这时我们有两种策略,一种是将叶子节点的yyy更新为xxx,另一种是将当前节点的yyy值设置为xxx值叶子节点的yyy值不变,很明显前者比较小。
上面的证明或许有些不严谨,但如果看了我下面的讨论之后你的思路可能会更加清晰。
我们选择一个权值xxx最大的节点作为根节点。
如果根节点只有一个子节点,它就成了树的一个端点,它的yyy值必然设定为它的xxx值。
另一方面,如果根节点有两个及以上的子节点,我们可以在它的两个子节点构成的子树中选择两个合适的叶子节点(何为合适等会儿会讲解)将它们的yyy值设定为xxx值。
然后是接下来子树的讨论。我们设置一个pairpairpair数组分别存放当前子树最大的叶子节点及其对应的yyy值。对于一个叶子节点我们将它们的yyy值设置为xxx值,对于一棵子树的根节点我们选择遍历它的所有子节点找到yyy值最大的节点及其yyy值,用该yyy值和根节点的xxx的较大值来更新该叶子节点的yyy值——这时当前节点一个端点的更新,另一个端点可以直接选择经过维护整棵树的根节点的两个端点的任意一个,因为整棵树的根节点xxx权值最大,所以能够维护整棵树的根节点必然也能够维护子树的根节点。
最后来讲一下我们上面提到的合适的叶子节点。因为我们要使权值和最小,所以我们要遍历根节点所有叶子节点找到最大值和次大值及其对应的叶子节点,并更新它的yyy值。
#include<bits/stdc++.h>
using namespace std;
using i64=long long;
signed main(){
ios_base::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int n;
cin>>n;
vector<i64> h(n+1,0);
vector<vector<int>> g;
g.resize(n+1);
for(int i=1;i<=n;i++) cin>>h[i];
for(int i=1,u,v;i<n&&cin>>u>>v;i++) g[u].push_back(v),g[v].push_back(u);
int root=1;
for(int i=2;i<=n;i++) if(h[i]>h[root]) root=i;
vector<pair<i64,int>> maxi(n+1,{0,0});
auto dfs=[&g,&root,&h,&maxi](auto self,int fa,int x)->void{
for(auto v:g[x]) if(v!=fa) self(self,x,v);
if(root==x){
if(g[root].size()==1){
maxi[x].first=max(maxi[x].first,h[x]),maxi[x].second=x;
maxi[maxi[g[root][0]].second].first=max(maxi[maxi[g[root][0]].second].first,h[x]);
}else{
int maximum1=0,maximum2=0;
int maxid1=0,maxid2=0;
for(auto v:g[x]){
if(maximum1<=maxi[v].first){
maximum2=maximum1,maximum1=maxi[v].first;
maxid2=maxid1,maxid1=v;
}else if(maximum2<maxi[v].first){
maximum2=maxi[v].first,maxid2=v;
}
}
maxi[maxi[maxid1].second].first=max(maxi[maxi[maxid1].second].first,h[x]);
maxi[maxi[maxid2].second].first=max(maxi[maxi[maxid2].second].first,h[x]);
}
return ;
}
if(g[x].size()==1){
maxi[x].first=max(maxi[x].first,h[x]);
maxi[x].second=x;
return;
}
int maxid=0,maximum=0;
for(auto v:g[x]){
if(v==fa) continue;
if(maximum<maxi[v].first) maximum=maxi[v].first,maxid=maxi[v].second;
}
maxi[x].first=maxi[maxi[maxid].second].first=max(maxi[maxi[maxid].second].first,h[x]);
maxi[x].second=maxi[maxid].second;
};
dfs(dfs,-1,root);
i64 ans=0;
for(int i=1;i<=n;i++) if(g[i].size()==1) ans+=maxi[i].first;
cout<<ans<<endl;
return 0;
}