树形dp。因为最终答案要求的是差值,所以对于每个节点我们只需要考虑差值即可,dp[u][0]表示u节点以下(含u)中like的val比candle的val大的最大值,dp[u][1]表示candle的val值比like大的最大值。所以,对于每个节点我们就可以对于每个子树v:
dp[u][0] += dp[v][0];
dp[u][1] += dp[v][1];
但是对于原来已经翻转过的点,我们要无偿交换swap(dp[u][0], dp[u][1]);这点应该很明确(因为我们本来求的是没有翻转过的,现在由于这一点本来翻转过,导致我们本来求的东西都是相反的,所以要交换一下)。
这样的话就可以使得两个值都最大,然后再考虑翻转的情况,如果翻转的话,那么可以这样更新dp值。
dp[u][0] = max(dp[u][0], dp[u][1] - sub);
dp[u][1] = max(dp[u][1], dp[u][0] - sub);
sub表示翻转需要的代价。
然后需要注意的一点就是对于树根0,我们不能翻转,根据题目可以很容易看出来。
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define REP(i, n) for(int i=0; i<n; i++)
#define CLR(a, b) memset(a, b, sizeof(a))
#define LL long long
using namespace std;
const int N = 50500;
struct Node
{
int v, s, p;
Node(){}
Node(int v, int s, int p):v(v), s(s), p(p){}
}node[N];
vector<int> E[N];
int x, y, n;
int dp[N][2];
void dfs(int u)
{
int sub;
dp[u][0] = node[u].v;
dp[u][1] = -node[u].v;
for(int i = 0; i < E[u].size(); i ++)
{
int v = E[u][i];
dfs(v);
dp[u][0] += dp[v][0];
dp[u][1] += dp[v][1];
}
if(node[u].s) sub = y;
else sub = x;
if(u)
{
if(node[u].s) swap(dp[u][0], dp[u][1]);
dp[u][0] = max(dp[u][0], dp[u][1] - sub);
dp[u][1] = max(dp[u][1], dp[u][0] - sub);
}
}
int main()
{
int v, s, p, f;
while(scanf("%d%d%d", &n, &x, &y) != EOF)
{
for(int i = 0; i <= n; i ++) E[i].clear();
for(int i = 1; i <= n; i ++)
{
scanf("%d%d%d%d", &v, &f, &s, &p);
if(p) v = -v;
node[i] = Node(v, s, p);
E[f].push_back(i);
}
dfs(0);
if(dp[0][0] < 0) puts("HAHAHAOMG");
else printf("%d\n", dp[0][0]);
}
}