题意给定n个点的树,每个点有一个点权wi, 每次选一个点u,则树上u和距离u wi范围内的所有点都会被染色。
问:最少选几个点使得n个点都被染色。
思路:树形DP。这道题无论是从思路还是实现都有点难度。
首先用dp[i][j]表示以i为根的子树的点全被染色且可以从i点向上再染j个距离的最小值,
用f[i]][j]表示以i为根的子树在距离i为j处的有子孙节点未被染色且距离超过j的子孙节点均被染色的最小值。
那么可以得到三个状态转移方程:
1.i点不染色。
dp[i][j] = min( dp[u][j+1] + sigma( min( dp[v][0,1,2...j+1], f[v][0,1,2...j-1] ) ) )
f[i][j] = min( f[u][j-1] + sigma( min( dp[v][0,1,2...j], f[v][0,1,2...j-1] ) ) )
2.i点染色
dp[i][w[i]] = min( dp[i][w[i]], sigma(dp[u][0,1,2...w[i]+1], f[u][0,1...w[i]-1])+1 )
其中u,v为i的儿子节点。
这样看起来还是很麻烦,我们稍微化简一下,
用数组mindp[i][j]表示dp[i][0]到dp[i][j]的最小值,
用数组minf[i][j]表示f[i][0]到f[i][j]的最小值,
那么原来的状态转移方程可以改写为(以第一个为例)
dp[i][j] = min( dp[u][j+1] + sigma( min( mindp[v][j+1],minf[v][j-1] ) ) )
再预处理一下,用sumdp[i][j]表示i结点儿子的sigma( min( mindp[v][j+1],minf[v][j-1] ))值,sumf类似
于是就可以以O(n*100)的复杂度得出答案。
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<queue>
#include<stack>
#include<string>
#include<map>
#include<set>
#include<ctime>
#define eps 1e-6
#define LL long long
#define pii pair<int, int>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
const int MAXN = 100100;
const int INF = 1000000;
int n, w[MAXN];
int dp[MAXN][102], f[MAXN][102];
int mindp[MAXN][102], minf[MAXN][102], sumdp[MAXN][102], sumf[MAXN][102];
vector<int> G[MAXN];
void dfs(int cur, int fa) {
for(int i = 0; i <= 101; i++) dp[cur][i] = f[cur][i] = mindp[cur][i] = minf[cur][i] = INF;
int sz = G[cur].size();
for(int i = 0; i < sz; i++) {
int u = G[cur][i];
if(u == fa) continue;
dfs(u, cur);
sumdp[cur][0] += mindp[u][1];
sumf[cur][0] += mindp[u][0];
for(int j = 1; j <= 100; j++) {
sumdp[cur][j] += min(mindp[u][j+1], minf[u][j-1]);
sumf[cur][j] += min(minf[u][j-1], mindp[u][j]);
}
}
f[cur][0] = sumf[cur][0];
for(int i = 0; i < sz; i++) {
int u = G[cur][i];
if(u == fa) continue;
dp[cur][0] = min(dp[cur][0], dp[u][1]+sumdp[cur][0]-mindp[u][1]);
for(int j = 1; j <= 100; j++) {
dp[cur][j] = min(dp[cur][j], dp[u][j+1]+sumdp[cur][j]-min(mindp[u][j+1], minf[u][j-1]));
f[cur][j] = min(f[cur][j], f[u][j-1]+sumf[cur][j]-min(minf[u][j-1], mindp[u][j]));
}
}
dp[cur][w[cur]] = min(dp[cur][w[cur]], sumdp[cur][w[cur]]+1);
mindp[cur][0] = dp[cur][0];
minf[cur][0] = f[cur][0];
for(int i = 1; i <= 100; i++) {
mindp[cur][i] = min(dp[cur][i], mindp[cur][i-1]);
minf[cur][i] = min(f[cur][i], minf[cur][i-1]);
}
}
void init() {
memset(sumdp, 0, sizeof(sumdp));
memset(sumf, 0, sizeof(sumf));
for(int i = 1; i <= n; i++) G[i].clear();
}
int main() {
//freopen("input.txt", "r", stdin);
while(cin >> n) {
init();
for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0);
int ans = INF;
for(int i = 0; i <= 100; i++) ans = min(ans, dp[1][i]);
cout << ans << endl;
}
return 0;
}