题目大意:
就是一个有N个点的树现在要从这棵树上选出K个点使得这K个点两两之间没有祖先关系, 即任意一个都不是另外一个的祖先, 那么选出K个这样的点能得到的最大的权值和是多少
大致思路:
首先这题如果K不大的话可以用树形DP直接弄, 不过这个题K比较大, 于是需要用一个巧妙的方法
先膜拜一下vawait....
好了进入正题
首先处理出每个节点u, 和其子树中所有节点的权值中的最大值w[u]
然后处理出对于每个节点u, 从其子树中选出两个没有祖先关系的点的权值和的最大值(不包括节点u自己) f[u]
那么用优先队列维护一下,
初始的时候加入(w[root], root)
然后每次选择当前最大的出队(w[u], u), 如果是用过的(即u及其子树中选择了最大值w[u])那么就出队, 取消这个点的选择状态, 将所有u的儿子(w[u], u)入队
否则就选择这个点, 然后将(f[u] - w[u], u)入队
这样保证了当u有两个子节点权值和比w[u](u及其子树中任意一个结点权值的最大值)大的时候, 会放弃选择u而转向选择u的子节点
代码如下:
Result : Accepted Memory : 7428 KB Time : 608 ms
/*
* Author: Gatevin
* Created Time: 2015/8/10 20:10:18
* File Name: Sakura_Chiyo.cpp
*/
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
#define maxn 100010
int N, K, root;
int w[maxn];
bool select[maxn];
vector<int> G[maxn];
int f[maxn];//f[u]表示从u的子孙节点中选两个权值和的最大值
void dfs(int now)//处理出每个子树u及其子树中最小的权值w[u], 以及其子树中两个没有关联的父亲关系的点权值最大和
{
int nex;
int mx = -1e9;//表示前几个子树中的最小w
for(int i = 0, sz = G[now].size(); i < sz; i++)
{
nex = G[now][i];
dfs(G[now][i]);
f[now] = max(f[now], max(f[nex], w[nex] + mx));
w[now] = max(w[now], w[nex]);
mx = max(mx, w[nex]);
}
return;
}
void solve()
{
memset(select, 0, sizeof(select));
priority_queue<pair<int, int> > Q;
Q.push(make_pair(w[root], root));
int ret = 0;
while(K)
{
if(Q.empty())//选不出K个
{
puts("0");
return;
}
int u = Q.top().second;
Q.pop();
if(select[u])//说明现在选择u的子节点中的两个要更优
{
K++;
select[u] = 0;//放弃选择子树的根节点u
ret -= w[u];
for(int i = 0, sz = G[u].size(); i < sz; i++) Q.push(make_pair(w[G[u][i]], G[u][i]));//加入其儿子节点
}
else
{
K--;
select[u] = 1;
ret += w[u];//选择了u结点
Q.push(make_pair(f[u] - w[u], u));//那么当选择两个其子节点和和它的差值要加入队列, 以确定是否要放弃这个结点
}
}
printf("%d\n", ret);
return;
}
int main()
{
while(scanf("%d %d", &N, &K), N || K)
{
for(int i = 1; i <= N; i++) G[i].clear();
int p;
for(int i = 1; i <= N; i++)
{
f[i] = -1e9;
scanf("%d %d", &p, w + i);
if(p == 0) root = i;
else G[p].push_back(i);
}
dfs(root);
solve();
}
return 0;
}