看上去各种二分答案+贪心啊
不是很可做
只会 O(n^2) :每次找最深的点并向上跳 mid 步,将跳到的点燃,这样套个倍增,预先把点按深度排序的话 复杂度大概是 O(n^2*log^2) 的
看了一发题解依然看不懂,于是颓这题颓了一下午...
先引用 这里 的一句话:选择一些点代价相同的话一般是贪心,代价不同一般是 dp
此题比较麻烦的就是需要在 O(n) 的时间内验证答案
log 大概是不太可行了,没有什么骚操作预处理连点都扫不全
考虑 dfs 一遍,(以下说的点大部分都是需要覆盖的点
对于一个点的来说,它的子树中如果有一个没被覆盖的点距它距离 > mid 了,这肯定是不合法的,很容易想到我们每次判断最深未覆盖点的距离,若 = mid 则把当前点点燃
现在考虑距离小于 mid 的未覆盖点,我需要记录 x 的子树中距离 x 最近的点燃的点烧到 x 后还能再烧多长,这样可以让子树之间互相更新,还可以根据这个距离来更新 x
以上信息都可以自底向上更新,所以我们记录 dep[x] 表示子树中最深的未覆盖点的距离, rem[x] 表示子树中最浅的点燃的点烧到 x 后还能再烧多长,就可以进行上边的操作了
这里我把初值设为 -1 为了方便区分这个点是否是需要覆盖的点,能否向上更新信息
注意再更新到了根的时候,根的信息需要在 dfs 外额外判断,由于可能 dep[Root] < mid ,就是说点燃根的某个祖先来更新根的子树中的点,所以判断一下是否需要把根点燃即可
写了一个万恶的特判WA了好久...
代码:
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cstdio>
#include<cmath>
using namespace std;
const int MAXN = 300005;
struct EDGE{
int nxt, to;
EDGE(int NXT = 0, int TO = 0) {nxt = NXT; to = TO;}
}edge[MAXN << 1];
int n, m, totedge;
int head[MAXN], rem[MAXN], dep[MAXN];
bool has[MAXN], is[MAXN], dis[MAXN];
inline void add(int x, int y) {
edge[++totedge] = EDGE(head[x], y);
head[x] = totedge;
return;
}
void predfs(int x, int fa) {
for(int i = head[x]; i; i = edge[i].nxt) if(edge[i].to != fa) {
int y = edge[i].to;
predfs(y, x);
has[x] |= has[y];
}
return;
}
int dfs(int x, int fa, int mid) {
int tot = 0;
for(int i = head[x]; i; i = edge[i].nxt) if(edge[i].to != fa && has[edge[i].to]) {
int y = edge[i].to;
tot += dfs(y, x, mid);
if(dep[y] + 1 == mid && rem[x] != mid) {
rem[x] = mid;
++tot;
dep[y] = -1;
}
rem[x] = max(rem[x], rem[y] - 1);
}
for(int i = head[x]; i; i = edge[i].nxt) if(edge[i].to != fa && has[edge[i].to]) {
int y = edge[i].to;
if(dep[y] + 1 <= rem[x]) {
dep[y] = -1;
continue;
}
if(~dep[y]) dep[x] = max(dep[x], dep[y] + 1);
}
if(is[x]) dep[x] = max(dep[x], 0);
if(dep[x] == 0 && rem[x] >= 0) dep[x] = -1;
return tot;
}
inline bool chk(int mid) {
for(int i = 1; i <= n; ++i) rem[i] = dep[i] = -1;
int tmp = dfs(1, 0, mid);
if(~dep[1]) ++tmp;
return (tmp <= m);
}
inline void hfs(int l, int r) {
int mid = ((l + r) >> 1);
while(l < r) {
mid = ((l + r) >> 1);
if(chk(mid)) r = mid;
else l = mid + 1;
}
printf("%d\n", l);
}
int main() {
scanf("%d%d", &n, &m);
register int xx, yy = 0;
for(int i = 1; i <= n; ++i) {
scanf("%d", &xx);
yy += xx;
rem[i] = dep[i] = -1;
has[i] = is[i] = xx;
}
if(m >= yy) { //辣鸡特判毁我青春
puts("0");
return 0;
}
for(int i = 1; i < n; ++i) {
scanf("%d%d", &xx, &yy);
add(xx, yy); add(yy, xx);
}
predfs(1, 0);
hfs(1, (n - m) / m + 1);
return 0;
}