题意:
给定 n
个节点的树,每个节点的值 为 0
或 1
。
现在需要你对树的 每个节点 v
求出:包含 v
的联通子图中,节点 1
的数量减去 0
的数量,最多 是多少。
思路:
虽然题设是以 联通子图 为对象,但是为了方便后续对 联通子图 进行 分集合讨论,我们还是可以考虑先在 第一遍 dfs1
的时候,以 子树 为对象 进行讨论。
设置 dp
数组,dp[i]
表示以 i
为根的子树中,Count(1) - Count(0)
的最大值。
例如 根节点是 u
,我们可以 进行 dp
,假设 u
的权值为 1
,它的 dp[u]
值等于它所有子树 v
中,所有为正数的 dp[v]
的累加和。也就是:
当 u
的权值为 -1
,也是类似的,只不过要 改为 -(a[u] == 0)
,因为贡献为 负,代码中有体现。
void dfs1(int u, int f)
{
if (a[u] == 1) dp[u] = 1;
else dp[u] = -1;
for (auto v : g[u])
{
if (v == f) continue;
dfs1(v, u);
dp[u] += max(0ll, dp[v]);
}
}
注意,这里是 从下往上的,从 叶子节点到根节点。
假设我们先 以 1
号节点为根,dp
了一遍,dp[i]
表示的是节点 i
的子树中,Count(1) - Count(0)
的最大值。现在看这个图,我们开始分析 第二遍 dfs2
假设现在需要计算 包含 u = 5
号节点的最大差值。那么我们可以把这棵树 分成两个部分:
我们将 f[u]
定义为:包含 节点 u
的联通子图最大差值。
现在 计算 f[5]
,用 x、y
标出了这两个部分。
我们发现 y
部分,就 等于 dp[5]
,那么 dp[5] > 0
就累加进 f[5]
。
再来看看 x
部分:
- 求
x
部分,可不可以 直接用dp[1] - dp[5]
?不可以,因为你不知道y
这一部分有没有贡献到f[1]
中,即使知道f[5]
的正负。比如2
和3
,它们是1
的儿子,那么dp[2]
和dp[3]
可以确定是否贡献到dp[1]
中,也就是只要 儿子的dp[]
值是 正数,那么 可以贡献到父节点中。而5
和1
的关系是孙子。孙子是 不一定 能贡献到爷爷的。举个例子,比如2
和5
节点 之间有100
个值为0
的,那么y
部分 就算 全部是1
也不可能 贡献进dp[1]
所以我们算 x
部分 的时候,应该 从 y
部分的根节点考虑。也就是 从 5
号节点考虑。我们只能知道 5
号节点能不能给他的父亲 2
号节点贡献。考虑到这里,就很容易算了,用 f[2] - max(dp[5], 0)
即可。也就是 如果 dp[5] > 0
那么它是 贡献进了 f[2]
的,那么就需要 减去。否则 没有贡献进去,就 不需要减去。
所以 f
的递推式 是这样的:
其中,x = f[fa[u]] - max(dp[u], 0)
void dfs2(int u, int fa)
{
for (auto v : g[u])
{
if (v == fa) continue;
int x = f[u] - max(0ll, dp[v]);
f[v] = dp[v] + max(0ll, x);
dfs2(v, u);
}
}
这里注意一点,对于 x
部分 要判断 正负 以决定 是否加入贡献。而 y
部分,不需要考虑正负,直接 贡献。因为,我们 求的 f[u]
一定是包含 u
的,而 dp[u]
就刚好包含 u
。
时间复杂度:
O ( n ) O(n) O(n)
代码:
#include <bits/stdc++.h>
using namespace std;
//#define map unordered_map
#define int long long
const int N = 2e5 + 10;
int n;
int a[N];
vector<int> g[N];
int dp[N], f[N];
void dfs1(int u, int f)
{
if (a[u] == 1) dp[u] = 1;
else dp[u] = -1;
for (auto v : g[u])
{
if (v == f) continue;
dfs1(v, u);
dp[u] += max(0ll, dp[v]);
}
}
void dfs2(int u, int fa)
{
for (auto v : g[u])
{
if (v == fa) continue;
int x = f[u] - max(0ll, dp[v]);
f[v] = dp[v] + max(0ll, x);
dfs2(v, u);
}
}
signed main()
{
cin >> n;
for (int i = 1; i <= n; ++i)
{
scanf("%lld", &a[i]);
}
int t = n - 1;
while (t--)
{
int x, y; scanf("%lld%lld", &x, &y);
g[x].emplace_back(y);
g[y].emplace_back(x);
}
dfs1(1, 0);
f[1] = dp[1];
dfs2(1, 0);
for (int i = 1; i <= n; ++i)
{
printf("%lld ", f[i]);
}
puts("");
return 0;
}