好久没有搞过树形dp的题了,它对新人很不友好,我就来补一发超详细的题解吧。
一、题目
点此看题
题意
给定一棵树,其中每个号节点如果被点亮,就会对周围相连的节点发出ci格能量,点亮第i个节点需要的能量点数为di。问能点亮整棵树的最小能量花费。
数据范围
对于
50
%
50\%
50%的数据,
max
c
i
<
=
1
,
n
<
=
100000
\max{ci}<=1,n<=100000
maxci<=1,n<=100000
对于另外
50
%
50\%
50%的数据,
max
c
i
<
=
5
,
n
<
=
2000
\max{ci}<=5,n<=2000
maxci<=5,n<=2000
二、解法
第一眼看到这个数据范围,觉得特别神奇,我们考虑数据分治。
0x01 贪心
先来考虑第一个范围,因为此时的
c
c
c只包含0或1,我们考虑
c
=
1
c=1
c=1的点的点亮顺序。无论怎么点亮,减少的能量花费总量都是一定的(可以自举例子加深理解),所以最优解和顺序无关,我们只需优先点亮
c
=
1
c=1
c=1的点即可。
0x02 树形dp
第二个范围就没有那么容易了,受贪心的启发,我们考虑点亮顺序。利用树形dp,我们对局部的点亮顺序进行考虑我们定义
d
p
[
i
]
[
0
/
1
]
dp[i][0/1]
dp[i][0/1]为对于第
i
i
i个节点先染 它
/
/
/它的父亲,并把
i
i
i及
i
i
i的子树染完的最小花费。
显然答案为
d
p
[
1
]
[
0
]
dp[1][0]
dp[1][0](
1
1
1的父亲不存在,用它统计最优答案)。
现在我们考虑对
d
p
dp
dp数组的转移。
由于
c
c
c并不是很大,我们可以把它作为一个维度,为了辅助转移,我们再定义
t
m
p
[
s
]
tmp[s]
tmp[s]为染完
u
u
u的子树,并且接受了儿子们值为
s
s
s的能量传递的最小花费(这里并不用去管
u
u
u是否被染色),我们先遍历完子树,再用滚动数组的形式合并子树的信息,详细操作见下。
memset(tmp,0x3f,sizeof tmp);
tmp[0][0]=0;//初始化
for(int i=f[u];i;i=e[i].next)//枚举每个儿子
{
int v=e[i].v;
if(v==fa) continue;
cur^=1;//滚动数组,用tmp[cur^1]更新tmp[cur]
memset(tmp[cur],0x3f,sizeof tmp[cur]);
for(int j=0;j<=sum-c[v];j++)//此时我们可以使用dp[v]
{
tmp[cur][j+c[v]]=min(tmp[cur][j+c[v]],tmp[cur^1][j]+dp[v][0]);//先点亮v
tmp[cur][j]=min(tmp[cur][j],tmp[cur^1][j]+dp[v][1]);//先点亮u
}
}
拿到了
t
m
p
tmp
tmp之后,更新就要简单一些了,我们可写出如下转移:
d
p
[
u
]
[
0
]
=
m
i
n
(
d
p
[
u
]
[
0
]
,
m
a
x
(
t
m
p
[
i
]
,
t
m
p
[
i
]
−
i
+
d
[
u
]
)
)
;
dp[u][0]=min(dp[u][0],max(tmp[i],tmp[i]-i+d[u]));
dp[u][0]=min(dp[u][0],max(tmp[i],tmp[i]−i+d[u]));
d
p
[
u
]
[
1
]
=
m
i
n
(
d
p
[
u
]
[
1
]
,
m
a
x
(
t
m
p
[
i
]
,
t
m
p
[
i
]
−
i
+
d
[
u
]
−
c
[
f
a
]
)
)
;
dp[u][1]=min(dp[u][1],max(tmp[i],tmp[i]-i+d[u]-c[fa]));
dp[u][1]=min(dp[u][1],max(tmp[i],tmp[i]−i+d[u]−c[fa]));
其中
i
i
i是从儿子得到的能量,我们防止能量溢出(即有多余的能量),故取max。
0x3f 代码
作者口胡可能难以理解,更多请看完整版代码,也可以单独向作者提问qwq。
#include <cstdio>
#include <cstring>
const int MAXN = 100005;
const int MAXM = 2005;
int read()
{
int x=0,flag=1;
char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,tot,ans,mc,d[MAXN],c[MAXN],f[MAXN];
struct edge
{
int v,next;
} e[MAXN*2];
int max(int x,int y)
{
if(x<y) return y;
return x;
}
int min(int x,int y)
{
if(x<y) return x;
return y;
}
int dp[MAXM][2],tmp[2][5*MAXM];
void dfs(int u,int fa)
{
int sum=0,cur=0;
for(int i=f[u]; i; i=e[i].next)
{
int v=e[i].v;
if(v==fa) continue;
dfs(v,u);
sum+=c[v];
}
memset(tmp,0x3f,sizeof tmp);
tmp[0][0]=0;
for(int i=f[u]; i; i=e[i].next)
{
int v=e[i].v;
if(v==fa) continue;
cur^=1;
memset(tmp[cur],0x3f,sizeof tmp[cur]);
for(int j=0; j<=sum-c[v]; j++)
{
tmp[cur][j+c[v]]=min(tmp[cur][j+c[v]],tmp[cur^1][j]+dp[v][0]);
tmp[cur][j]=min(tmp[cur][j],tmp[cur^1][j]+dp[v][1]);
}
}
dp[u][0]=dp[u][1]=0x3f3f3f3f;
for(int i=0; i<=sum; i++)
{
dp[u][0]=min(dp[u][0],max(tmp[cur][i],tmp[cur][i]-i+d[u]));
dp[u][1]=min(dp[u][1],max(tmp[cur][i],tmp[cur][i]-i+d[u]-c[fa]));
}
return ;
}
int main()
{
n=read();
for(int i=1; i<=n; i++)
d[i]=read();
for(int i=1; i<=n; i++)
c[i]=read(),mc=max(mc,c[i]);
for(int i=1; i<n; i++)
{
int u=read(),v=read();
e[++tot]=edge{u,f[v]},f[v]=tot;
e[++tot]=edge{v,f[u]},f[u]=tot;
}
if(mc<=1)
{
for(int i=1; i<=n; i++)
if(c[i]==1)
{
ans+=d[i];
d[i]=0;
for(int j=f[i]; j; j=e[j].next)
d[e[j].v]--;
}
for(int i=1; i<=n; i++)
if(d[i]>0)
ans+=d[i];
printf("%d\n",ans);
return 0;
}
dfs(1,0);
printf("%d\n",dp[1][0]);
return 0;
}