虚树学习总结
前言:
终于搞掉了虚树这个大坑······
内容:
给定树中的关键点,为关键点建树从而将复杂度从和总点数相关降低到只与关键点数相关。
做法与实现:
所谓虚树,就是将不必要的结点省略掉,只留下需要的节点,假设题目要求我们涉及k个关键点,那么如果只需要这k个关键点之间的信息,实际上我们只需要维护k个关键点以及他们的lca即可,考虑到就算是关键点的lca,个数也是平方级的,如何优化呢,我们先将所有的关键点按照dfs序排序,对于三个在dfs序从小到大的点x,y, z显然lca(x,z)一定是lca(x,y), lca(y, z)中的一个,(显然k个点的lca一定只有k - 1个),那么现在我们将点数降到了O(k),考虑如何建树,我们维护一个最右链的栈,现在我们考虑新加入一个点,设这个点和当前栈顶的点的lca为pre,如果pre不为栈顶,那么显然的栈顶的子树中的关键点已经搞完了,那么我们就可以把栈顶弹掉,然后找到上一个点,如果依然深度小于pre继续重复上一个操作,知道我们找到一个点它的dep是小于等于pre的,现在的情况就有两种,第一这个点和pre相等,那么直接跳出即可,否则这个点一定是pre的祖先,上一个弹掉的点一定是pre的子树内的点,我们将pre和上一个弹掉的点连边,然后将pre入栈。最后将当前的新加入的点放入即可。建树过程非常想treap维护最右链的思想,可以结合理解一下。
注意:如果在建树过程中不判断是否和当前的pre和a[i]相等一般不会造成什么问题,但是会出现一些自环,如果不管的话,也可以在的代码中处理一下即可,还有一定要注意清空不能memset,只能清使用过的地方,边集可以在dfs最后清。
(PS:一般情况下,虚树的题,难点在DP······)
例题:消耗战
题目背景:
分析:虚树 + 树型DP
比较显然的肯定先把关键点提出来,然后有一个优化就是,如果存在一个点是另一个点的祖先那么只保留这个点就可以了,因为只断掉另一个点并无卵用,那么考虑具体如何DP,显然对于每一个点有两种方法,一:断掉所有的儿子,二:断掉自己,两者直接取min就是断掉自己的最优选择,而断掉自己的价格就是到根路径上的最小值,直接在过程中预处理即可。(叶节点只能断自己)
Source:
/*
created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>
inline char read() {
static const int IN_LEN = 1024 * 1024;
static char buf[IN_LEN], *s, *t;
if (s == t) {
t = (s = buf) + fread(buf, 1, IN_LEN, stdin);
if (s == t) return -1;
}
return *s++;
}
///*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = read(), iosig = false; !isdigit(c); c = read()) {
if (c == -1) return ;
if (c == '-') iosig = true;
}
for (x = 0; isdigit(c); c = read())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {
if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
*oh++ = c;
}
template<class T>
inline void W(T x) {
static int buf[30], cnt;
if (x == 0) write_char('0');
else {
if (x < 0) write_char('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) write_char(buf[cnt--]);
}
}
inline void flush() {
fwrite(obuf, 1, oh - obuf, stdout);
}
/*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
if (c == '-') iosig = true;
for (x = 0; isdigit(c); c = getchar())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int MAXN = 250000 + 10;
int n, m, x, y, z, ind;
int dfn[MAXN], father[MAXN], size[MAXN], dep[MAXN];
int son[MAXN], top[MAXN];
long long val[MAXN];
struct node {
int to, w;
node(int to = 0, int w = 0) : to(to), w(w) {}
} ;
std::vector<node> edge[MAXN], new_edge[MAXN];
inline void add_edge(int x, int y, int z) {
edge[x].push_back(node(y, z));
edge[y].push_back(node(x, z));
}
inline void add_new_edge(int x, int y) {
new_edge[x].push_back(node(y));
}
inline void read_in() {
R(n);
for (int i = 1; i < n; ++i) R(x), R(y), R(z), add_edge(x, y, z);
}
inline void dfs1(int cur, int fa) {
dfn[cur] = ++ind, father[cur] = fa, size[cur] = 1, dep[cur] = dep[fa] + 1;
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (e->to != fa) {
val[e->to] = std::min(val[cur], (long long)e->w);
dfs1(e->to, cur), size[cur] += size[e->to];
if (size[e->to] > size[son[cur]]) son[cur] = e->to;
}
}
}
inline void dfs2(int cur, int tp) {
top[cur] = tp;
if (son[cur]) dfs2(son[cur], tp);
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (e->to != father[cur] && e->to != son[cur]) dfs2(e->to, e->to);
}
}
inline int lca(int u, int v) {
while (top[u] != top[v])
dep[top[u]] > dep[top[v]] ? u = father[top[u]] : v = father[top[v]];
return dep[u] > dep[v] ? v : u;
}
inline bool comp(const int &a, const int &b) {
return dfn[a] < dfn[b];
}
inline long long dfs(int cur) {
if (new_edge[cur].size() == 0) return val[cur];
long long ans = 0;
for (int p = 0; p < new_edge[cur].size(); ++p) {
node *e = &new_edge[cur][p];
ans += dfs(e->to);
}
return new_edge[cur].clear(), std::min(ans, (long long)val[cur]);
}
inline void solve() {
static int top, cnt, n;
static int stack[MAXN], a[MAXN];
R(cnt), top = 1, n = 1;
for (int i = 1; i <= cnt; ++i) R(a[i]);
std::sort(a + 1, a + cnt + 1, comp);
for (int i = 2; i <= cnt; ++i)
if (lca(a[i], a[n]) != a[n]) a[++n] = a[i];
stack[top] = 1;
for (int i = 1; i <= n; ++i) {
int pre = lca(a[i], stack[top]);
while (true) {
if (dep[stack[top - 1]] <= dep[pre]) {
if (pre == stack[top]) break ;
add_new_edge(pre, stack[top--]);
if (stack[top] != pre) stack[++top] = pre;
break ;
}
add_new_edge(stack[top - 1], stack[top]), top--;
}
if (stack[top] != a[i]) stack[++top] = a[i];
}
while (top > 1) add_new_edge(stack[top - 1], stack[top]), --top;
W(dfs(1)), write_char('\n');
}
int main() {
read_in();
val[1] = 1e18;
dfs1(1, 0);
dfs2(1, 1);
R(m);
while (m--) solve();
flush();
return 0;
}