此文章亦适合初学者入门使用,但建议对于每个题目应自己思考约15分钟后回来看题解。
前言
树形dp入门较为简单,作为dp来讲,相较于无规则可循的dp和复杂的状压dp,高维dp。树形dp的树形依赖关系明显,常常能够让人看出是树形dp的题目,但是由于其dp依赖于DFS,所以大多数时候这类dp的难度分为状态设计上和如何配合DFS写出转移。
而作为一个在ICPC的比赛中是常客的知识点,在铜牌之后,树形DP是银牌算法中较为简单可写的算法之一。
树形背包
树形背包是根据树形依赖关系构成的树上分组背包问题。
树形背包顾名思义是在树上做背包的一种树形DP,原理上是基于树形关系的分组背包,将某节点的子节点代表的子树的不同状态视为某个物品的不同状态,自底向上更新选择哪个状态是最优的,以此来达到dp的效果。
在做了一些树形背包题之后发现,原来我从未学会过树形背包,今天才刚刚学会。
我们先提出一个问题,树形背包他标准状态是几维?或者说常用状态是几维的?
答案是3维。
明确这样一个状态dp[0/1][i][j]代表在以i为根的子树中选择j个子节点,并且包含/不包含节点i的情况下,最大/最小的贡献是多少。
大概有如下模板
void DFS(int now,int fa)
{
//初始化
for(int i=head[now];i;i=edge[i].nex)
{
//搜索DFS(to,now)
//DP过程
for(int j=以now为根的子树内能选择的最多的物品数:min(背包容量限制,当前子节点数);j>=在以now为根的子树里你最少选多少个;--j)
{
for(int k=0;k<=以now的儿子节点为根的子节点内最多能选择的物品数;k++)//物品数不一定必须是节点数
//我们根据视儿子节点为一个物品,那么他选择不同的容量的时候,就是这个物品的不同状态,可看做分组背包
{
//状态转移
}
}
siz[now]+=siz[to];//注意这里很关键,必须写在这里,因为上方dp过程中now节点在dp的时候还不能包括目前的这个儿子节点的子节点。
}
}
至于你说dp过程中,第一层循环可以从:全部节点数到最少选择节点数来表示背包容量。
我的评价是:可以,但是会超时,在竞赛里,既然要考察你这个知识点,肯定就不是你用含糊其辞的做法可以过去的,luogu的选课题解大多都是像这样写的,但是他的数据大小极小,而且情况也很特殊,存在极强的误导性。
例题
P2014 [CTSC1997] 选课
入门题。
这道题其实有些误导的倾向,我们这里所说的dp是三维状态但是他只写了两维,也就是0/1维度没有,读者可以思考一下为什么会这样。
答案是题目中要求必须选了父亲节点才能选儿子节点,所以对于每个根节点都是必选的所以就会造成0/1维一直是1,所以没必要去写。
#include<bits/stdc++.h>
using namespace std;
const int N =400;
struct node
{
int nex,to;
};
node edge[N];
int head[N],tot;
void add(int from,int to)
{
edge[++tot].to=to;
edge[tot].nex=head[from];
head[from]=tot;
}
int n,m;
int dp[N][N];
int score[N];
int siz[N];
void DFS(int now)
{
dp[now][1]=score[now];
siz[now]=1;
for(int i=head[now];i;i=edge[i].nex)
{
int to=edge[i].to;
DFS(edge[i].to);
for(int j=siz[now];j>=1;j--)//注意这里j最小是1,因为我们必须选一个节点,就是当前子树的根节点。
{
for(int k=0;k<=siz[to];k++)
{
dp[now][j+k]=max(dp[now][j+k],dp[now][j]+dp[edge[i].to][k]);
}
}
siz[now]+=siz[to];
}
}
signed main()
{
cin>>n>>m;
for(int i=1,a;i<=n;i++)
{
cin>>a>>score[i];
add(a,i);
}
DFS(0);
cout<<dp[0][m+1]<<endl;
return 0;
}
L. Perfect Matchings(树形背包+组合数学)
题意:给定一个树,问你最多能搞出多少个不同的完美匹配(两个点一条线相连,并且这两个点不再和其他点相连)。
这个题拿出来主要是注意在DP的转移过程中
1.我们的物品未必总是节点,也可能是其他和节点相关的信息,比如这个题,完美匹配数就是:节点数/2
2.状态转移依然是dp问题的难点。我们需要配合DFS考虑边界情况等问题。
重点放在DFS上,这题的难点是组合数学,但是这里我们只看DFS
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
#define int long long
const int N=4000+100;
const int mod=998244353;
typedef long long ll;
int n,siz[N],f[N][N][2],g[N];
ll ans;
struct node
{
int nex,to;
}edge[N<<1];
int head[N],tot;
void add(int from,int to)
{
edge[++tot].to=to;
edge[tot].nex=head[from];
head[from]=tot;
}
void DFS(int now,int fa)
{
f[now][0][0]=1;
siz[now]=1;
for(int i=head[now];i;i=edge[i].nex)
{
int to=edge[i].to;
if(fa==to)
continue;
DFS(to,now);
for(int i=siz[now]/2;i>=0;--i)
{
for(int j=0;j<=siz[to]/2;j++)
{
if(j)
{
f[now][i+j][0]=(f[now][i+j][0]+(ll)(f[to][j][0]+f[to][j][1])*f[now][i][0])%mod;
f[now][i+j][1]=(f[now][i+j][1]+(ll)(f[to][j][0]+f[to][j][1])*f[now][i][1])%mod;
}
f[now][i+j+1][1]=(f[now][i+j+1][1]+(ll)f[to][j][0]*f[now][i][0])%mod;
}
}
siz[now]+=siz[to];
}
}
signed main()
{
cin>>n;
for(int i=1,x,y;i<(n<<1);++i)
{
cin>>x>>y;
add(x,y);
add(y,x);
}
g[0]=1;
for(int i=1;i<=n;++i)
g[i]=(ll)g[i-1]*((i<<1)-1)%mod;
DFS(1,0);
for(int i=0;i<=n;++i)
{
if(i&1)
ans=(ans+(ll)(mod-1)*(f[1][i][0]+f[1][i][1])%mod*g[n-i])%mod;
else
ans=(ans+(ll)(f[1][i][0]+f[1][i][1])*g[n-i])%mod;
}
cout<<ans;
}
题意:每个点有一个权值,我们拿走一个点的权值的花费是这个点的权值和他所有子节点的权值之和。
这个题拿出来的目的
1.理解模板的通用性
2.看清初始化应该注意的问题。
#include <bits/stdc++.h>
#define endl '\n'
#define int long long
using namespace std;
const int N = 2e3+100;
struct node
{
int nex,to;
}edge[N<<2];
int tot,n,a[N],head[N];
void add(int from,int to)
{
edge[++tot].nex=head[from];
edge[tot].to=to;
head[from]=tot;
}
int siz[N],dp[2][N][N];
void DFS(int now,int fa)
{
siz[now]=1;
for(int i=0;i<=n;i++)
dp[0][now][i]=dp[1][now][i]=1e18;
dp[1][now][1]=a[now];
dp[0][now][0]=0;
for(int i=head[now];i;i=edge[i].nex)
{
int to=edge[i].to;
if(to==fa)
continue;
DFS(to,now);
for(int j=siz[now];j>=0;j--)
{
for(int k=0;k<=siz[to];k++)
{
dp[0][now][j+k]=min(dp[0][now][j+k],dp[0][now][j]+dp[0][to][k]);
dp[0][now][j+k]=min(dp[0][now][j+k],dp[0][now][j]+dp[1][to][k]);
dp[1][now][j+k]=min(dp[1][now][j+k],dp[1][now][j]+dp[0][to][k]);
dp[1][now][j+k]=min(dp[1][now][j+k],dp[1][now][j]+a[to]+dp[1][to][k]);
}
}
siz[now]+=siz[to];
}
}
void init()
{
for(int i=0;i<=n;i++)
head[i]=0;
for(int i=0;i<=tot;i++)
edge[i].nex=edge[i].to=0;
tot=0;
}
signed main()
{
//cin.tie(0);cout.tie(0);ios::sync_with_stdio(0);
int t;
for(cin>>t;t;t--)
{
cin>>n;
for(int i=2,x;i<=n;i++)
{
cin>>x;
add(i,x);
add(x,i);
}
for(int i=1;i<=n;i++)
cin>>a[i];
DFS(1,0);
for(int i=0;i<=n;i++)
cout<<min(dp[0][1][n-i],dp[1][1][n-i])<<" ";
cout<<endl;
init();
}
return 0;
}
一般树形DP
其实我也不清楚是否有必要把一般的树形DP和树上背包分开,但是这两者本质上还是有些许区别的,从状态设计上和转移方程上都不是很相近。所以我们把树形背包单独拿出来不和树形DP混为一谈。
而树形DP,我认为是比树形背包简单的,他只是单纯的利用树形依赖关系去做DP,就是DFS上DP,背包还要理解分组背包的思想才能说会了树形背包。
树形DP也有一般情况下的模板如下。
void DFS(int now,int fa)
{
for(int i=head[now];i;i=edge[i].nex)
{
//搜索 DFS(to,now);
//状态转移
}
//信息处理
}
P1352 没有上司的舞会
入门题
作为最简单的树形DP了吧这算是,理顺清楚子节点和父节点的关系去做DP即可
#include <iostream>
#include <bits/stdc++.h>
using namespace std;
const int N = 6e3+100;
struct node
{
int nex,to;
}edge[N<<1];
int tot,head[N];
void add(const int &from,const int &to)
{
edge[++tot].to=to;
edge[tot].nex=head[from];
head[from]=tot;
}
int dp[N][2],a[N];
void DFS(int now,int fa)
{
for(int i=head[now];i;i=edge[i].nex)
{
int to=edge[i].to;
if(to==fa)
continue;
DFS(to,now);
dp[now][0]=max({dp[to][1],dp[now][0]+dp[to][1],dp[now][0]+dp[to][0],dp[to][0],dp[now][0]});
dp[now][1]=max({dp[now][1],dp[now][1]+dp[to][0]});
}
}
int main()
{
int n;
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>dp[i][1];
}
for(int i=1,a,b;i<n;i++)
{
cin>>a>>b;
add(a,b);
add(b,a);
}
DFS(1,0);
cout<<max(dp[1][0],dp[1][1])<<endl;
return 0;
}
H. Crystalfly
题意:每个节点都有数量不同的萤火虫(我正好在<听夜,萤火虫和你>)并且有一个t值,我们跑到一个节点的时候他的所有儿子节点都会被激活,被激活t秒后此节点的萤火虫全部飞走。t取值为:1 2 3
这题哪来的目的
1.树形dp有时不一定是DP的问题,而是后序信息处理上不好写,我们所说的:配合DFS来写出dp的转移,也正是如此。
#include <bits/stdc++.h>
using ll = long long;
using namespace std;
int n;
const int N = 1e5 + 10, M = N * 2;
int e[M], ne[M], h[N], idx;
ll t[N], f[N], g[N], a[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void dfs(int u, int fa) {
g[u] = a[u];
ll maxn1 = 0, maxn2 = 0;
for(int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if(j == fa) continue;
dfs(j, u);
g[u] += f[j] - a[j]; // 拿 u 但是不拿 u 的直接子节点
ll temp = g[j] - (f[j] - a[j]);
if(maxn1 < temp) {
maxn2 = maxn1;
maxn1 = temp;
} else if(maxn2 < temp) maxn2 = temp;
}
f[u] = g[u]; //删掉后叶子节点无法更新
for(int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if(j == fa) continue;
f[u] = max(f[u], g[u] + a[j]);
if(t[j] == 3) {
if(g[j] - f[j] + a[j] == maxn1) f[u] = max(f[u], g[u] + a[j] + maxn2);
else f[u] = max(f[u], g[u] + a[j] + maxn1);
}
}
}
void solve() {
cin >> n;
memset(h, -1, sizeof h);
idx = 0;
for(int i = 1; i <= n; i ++) cin >> a[i], f[i] = g[i] = 0;
for(int i = 1; i <= n; i ++) cin >> t[i];
for(int i = 1; i < n; i ++) {
int a, b;
cin >> a >> b;
add(a, b), add(b, a);
}
dfs(1, -1);
cout << f[1] << '\n';
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
#ifndef ONLINE_JUDGE
freopen("D:/Cpp/program/Test.in", "r", stdin);
freopen("D:/Cpp/program/Test.out", "w", stdout);
#endif
int t;
cin >> t;
while(t --) solve();
}