题目大意
给定一颗树,其中有 n n n 个节点,根节点为 1 1 1。每次可以选择一个节点,其子树的大小为 i i i,删去该节点子树上的所有节点(不包括该节点),代价为 v a l i val_i vali。
求:当删去了除 1 1 1 节点以外的所有节点时,总代价最小为多少。
做法
考虑树形 dp。
规定 f i , j f_{i,j} fi,j 数组表示在在 i i i 节点的子树中删去 j j j 个节点的代价最小值。显然,答案为 f 1 , n − 1 f_{1,n-1} f1,n−1,因为无法删去 1 1 1 节点。
考虑删除过程。下图是样例的节点连接情况。作图来自 graph_editor。
因为最终要删去除了节点 1 1 1 以外所有的节点,所以我们遍历它的儿子。
观察到,如果删去 8 8 8 节点的子树是经济的,所以第一次删去 8 8 8 节点的子树,代价为 v a l 6 = 45104 val_6 = 45104 val6=45104。现在的树如下图。
第二次删去剩余的 1 1 1 的子树,代价为 v a l 3 = 18640 val_3 = 18640 val3=18640。所以答案为 63744 63744 63744。
给出状态转移方程,其中 s o n son son 表示 i i i 节点的儿子, n u m num num 表示子树大小: f i , j = min k ∈ [ max ( 1 , j − n u m i ) , min ( j , n u m s o n − 1 ) ] { f i , j − k + f s o n , k } f_{i,j}=\min \limits _{k \in [\max(1,j-num_i),\min(j,num_{son}-1)]} \{ f_{i,j-k}+f_{son,k} \} fi,j=k∈[max(1,j−numi),min(j,numson−1)]min{fi,j−k+fson,k}。
对于每个节点,在动态规划后还要进行一次答案更新: f i , n u m i − 1 = min j ∈ [ 0 , n u m i − 1 ) { f i , j + v a l n u m i − j } f_{i,num_i-1}=\min \limits _{j \in [0,num_i-1)} \{ f_{i,j}+val_{num_i-j} \} fi,numi−1=j∈[0,numi−1)min{fi,j+valnumi−j}。
最后注意一下,因为我们不知道树的方向,所以我们双向连边。
代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Getchar() p1==p2 and (p2=(p1=Inf)+fread(Inf,1,1<<21,stdin),p1==p2)?EOF:*p1++
char Inf[1<<21],*p1,*p2;
inline void read(int &x,char c=Getchar())
{
bool f=c!='-';
x=0;
while(c<48 or c>57) c=Getchar(),f&=c!='-';
while(c>=48 and c<=57) x=(x<<3)+(x<<1)+(c^48),c=Getchar();
x=f?x:-x;
}
int n,val[5010],f[5010][5010],num[5010];
vector<int> v[5010];
inline void solve(int pos,int fa,int cnt=0)//双向连边,fa防止向上搜索
{
for(int i=0;i<v[pos].size();i++)
{
if(v[pos][i]!=fa)
{
solve(v[pos][i],pos),cnt+=num[v[pos][i]];
for(int j=cnt-1;j>=1;j--)
for(int k=max((long long)1,j-num[pos]);k<=min(j,num[v[pos][i]]-1);k++)//j不能为负数,否则越界
f[pos][j]=min(f[pos][j],f[pos][j-k]+f[v[pos][i]][k]);
num[pos]+=num[v[pos][i]];
}
}
for(int i=0;i<num[pos]-1;i++) f[pos][num[pos]-1]=min(f[pos][num[pos]-1],f[pos][i]+val[num[pos]-i]);
}
signed main()
{
read(n),memset(f,0x3f,sizeof(f));
for(int i=1;i<=n;i++) f[i][0]=0,num[i]=1;
for(int i=2;i<=n;i++) read(val[i]);
for(int i=1,x,y;i<n;i++) read(x),read(y),v[x].push_back(y),v[y].push_back(x);
solve(1,0),cout<<f[1][num[1]-1];//这里无所谓,因为节点 1 的子树中有 num[1]=n 个节点
return 0;
}