题目描述:http://hihocoder.com/problemset/problem/1063
这道题用到树形dp来解,具体有两种方式来实现
先看第一种:
参考文章:http://blog.youkuaiyun.com/wsjingping/article/details/45875867
dp[i][j]表示以i为根的子树中,恰好得到j价值所走的最短距离,但是最后可以不回到起点,故有两种状态:dp[i][j][0]表示必须回到节点i,dp[i][j][1]表示可以不回来
个人认为,对于dp[i][j],不回来肯定小于等于回来,若回来小于不回来,那么再减去回来所花费的距离,将得到更小的不回来的距离,矛盾
#include <stdio.h>
#include <string.h>
#define MAX_N 100
#define INF (1<<31)-1
int dp[MAX_N+1][201][2], va[MAX_N+1], fe[MAX_N+1], n; //fe[]记录每个节点的第一个子节点
typedef struct Node{
int to, len, next; //next记录下一个兄弟节点
}Edge;
Edge e[MAX_N];
int min(int a, int b){
return a < b ? a : b;
}
void Add(int a, int b, int c, int ecount){
e[ecount].to = b;
e[ecount].len = c;
e[ecount].next = fe[a];
fe[a] = ecount;
}
void dfs(int u, int fa){
int i, ii, j, k, v, vv, l, ll, temp[201];
dp[u][va[u]][0] = 0;
for(i = fe[u]; i+1; i = e[i].next){
v = e[i].to;
l = e[i].len;
if(v != fa){
dfs(v, u);
for(j = n*2; j >= 0; j--){
if(dp[u][j][0] != INF){
for(k = 0; k <= n*2-j; k++){
if(dp[v][k][0] != INF){
dp[u][j+k][0] = min(dp[u][j+k][0], dp[u][j][0]+dp[v][k][0]+2*l);
}
}
}
}
}
}
dp[u][va[u]][1] = 0;
for(ii = fe[u]; ii+1; ii = e[ii].next){
vv = e[ii].to;
ll = e[ii].len;
if(vv != fa){
for(i = 0; i <= n*2; i++){
temp[i] = INF;
}
temp[va[u]] = 0;
for(i = fe[u]; i+1; i = e[i].next){
v = e[i].to;
l = e[i].len;
if(v != fa && v != vv){
for(j = n*2; j >= 0; j--){
if(temp[j] != INF){
for(k = 0; k <= n*2-j; k++){
if(dp[v][k][0] != INF){
temp[j+k] = min(temp[j+k], temp[j]+dp[v][k][0]+2*l);
}
}
}
}
}
}
for(i = n*2; i >= 0; i--){
if(temp[i] != INF){
for(j = 0; j <= n*2-i; j++){
if(dp[vv][j][1] != INF){
dp[u][i+j][1] = min(dp[u][i+j][1], temp[i] + dp[vv][j][1] + ll);
}
}
}
}
}
}
}
int main(){
int i, j, ecount, a, b, c, q, d;
ecount = 0;
memset(fe, -1, sizeof(fe));
scanf("%d", &n);
for(i = 1; i <= n; i++){
scanf("%d", &va[i]);
}
for(i = 1; i < n; i++){
scanf("%d%d%d", &a, &b, &c);
Add(a, b, c, ecount);
ecount++;
Add(b, a, c, ecount);
ecount++;
}
for(i = 1; i <= n; i++){
for(j = 0; j <= n*2; j++){
dp[i][j][0] = INF;
dp[i][j][1] = INF;
}
}
dfs(1, -1);
scanf("%d", &q);
while(q--){
scanf("%d", &d);
for(i = n*2; i >= 0; i--){
if(dp[1][i][1] <= d){
break;
}
}
printf("%d\n", i);
}
return 0;
}
方法二比较机智