思路
- 使用 f [ u ] f[u] f[u]记录以 u u u为根节点的子树,取当前 u u u节点的蝴蝶数,所能得到的最优值。
- 使用 g [ u ] g[u] g[u]记录以 u u u为根节点的子树,不取当前 u u u节点的蝴蝶数,所能得到的最优值。
- 使用 h [ u ] h[u] h[u]记录以 u u u为根节点的子树,取走当前 u u u节点的蝴蝶数,然后返回 u u u个父节点,去取 t [ v ] = = 3 t[v]==3 t[v]==3节点的蝴蝶数。这里的 v v v是u的兄弟。
然后我们可以发现:
f
[
u
]
=
a
[
u
]
+
g
[
u
]
f[u]=a[u]+g[u]
f[u]=a[u]+g[u]
h
[
u
]
=
a
[
u
]
+
∑
v
∈
t
r
e
e
[
u
]
g
[
v
]
h[u]=a[u]+\sum_{v\in tree[u]}g[v]
h[u]=a[u]+v∈tree[u]∑g[v]
然后考虑如何计算
g
[
v
]
:
g[v]:
g[v]:
首先我们只考虑取走一个在到达
v
v
v后被惊扰的节点的情况,此时会得到
代码
#include <iostream>
#include <vector>
#include <cstring>
#include <set>
#include <cmath>
#include <unordered_map>
#include <algorithm>
using namespace std;
using pii = pair<int, int>;
const int N = 1e5 + 10;
long long a[N], T[N], f[N], g[N], h[N];
vector<int> tree[N];
void dfs(int u, int fa)
{
long long maxv = 0, fi = -1, se = -1, totalg = 0;
h[u] = a[u];
for (auto v : tree[u])
{
if (v == fa)
{
continue;
}
dfs(v, u);
if (T[v] == 3)
{
if (a[v] >= fi)
{
se = fi;
fi = a[v];
}
else
{
se = max(se, a[v]);
}
}
totalg += g[v];
h[u] += g[v];
maxv = max(maxv, a[v]);
}
g[u] = maxv + totalg;
if (fi != -1)
{
for (int i = 0; i < tree[u].size(); i++)
{
int v = tree[u][i];
if (v == fa)
{
continue;
}
if (fi == a[v] && T[v] == 3)
{
if (se == -1)
{
continue;
}
g[u] = max(g[u], h[v] + totalg - g[v] + se);
}
else
{
g[u] = max(g[u], h[v] + totalg - g[v] + fi);
}
}
}
f[u] = g[u] + a[u];
}
int main()
{
cin.tie(0), cout.tie(0);
ios::sync_with_stdio(false);
int t;
cin >> t;
while (t--)
{
int n;
cin >> n;
for (int i = 1; i <= n; i++)
{
f[i] = 0;
g[i] = 0;
h[i] = 0;
tree[i].clear();
}
for (int i = 1; i <= n; i++)
{
cin >> a[i];
}
for (int i = 1; i <= n; i++)
{
cin >> T[i];
}
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
tree[u].push_back(v);
tree[v].push_back(u);
}
dfs(1, 0);
cout << f[1] << endl;
}
return 0;
}