虚树是什么
虚树指在原树上选择需要的点和它们的 L C A LCA LCA组成的一棵树。这样可以使在树DP时顶点数更少,从而减少时间复杂度。一般用于有多组数据且能保证所有数据访问的点的和不超过规定范围。
情景代入:SDOI2011消耗战
题目大意
给出一棵树,根节点为一号点,有 n n n顶点, n − 1 n-1 n−1条边,每条边都有边权,断掉一条边的代价为这条边的边权。有 m m m次询问,每次询问给出 k k k个询问点,问使这 k k k个点都不和根节点相连的最小代价。
数据范围
1 ≤ n ≤ 2.5 × 1 0 5 , 1 ≤ m ≤ 5 × 1 0 5 , 1 ≤ ∑ k ≤ 5 × 1 0 5 1\leq n\leq 2.5\times 10^5,1\leq m\leq 5\times 10^5,1\leq \sum k\leq 5\times 10^5 1≤n≤2.5×105,1≤m≤5×105,1≤∑k≤5×105
做法
我们可以用树型DP。设 f [ i ] f[i] f[i]表示子树 i i i与根节点断开的代价, m d [ i ] md[i] md[i]表示点 i i i到根节点的最小边权。分类讨论一下:
- 如果点 i i i是询问点,那么 f [ i ] = m d [ i ] f[i]=md[i] f[i]=md[i]
- 如果点 i i i不是询问点,那么 f [ i ] = m i n ( m d [ i ] , ∑ j ∈ s o n i f j ) f[i]=min(md[i],\sum\limits_{j\in son_i}f_j) f[i]=min(md[i],j∈soni∑fj)
可以用dfs来解决。
但如果直接这样做的话,时间复杂度为 O ( n m ) O(nm) O(nm),显然会TLE。又因为 k k k的和在 5 × 1 0 5 5\times 10^5 5×105以内,所以我们可以用虚树来解决。
对于每次询问,我们将询问点和它们的 L C A LCA LCA放到虚树中。举几个例子:
对于如下一棵树
如果查询点为6,10,那么构成的虚树如下
放进虚树的点即为查询点和它们的
L
C
A
LCA
LCA。
因为每加入一个点最多只会产生一个 L C A LCA LCA,所以如果有 k k k个有效的点,则虚树上最多只会有 2 k 2k 2k个点。
虚树如何建立
那么,虚树该如何建立呢?
首先,我们对原树进行dfs,按dfs序给每一个点打上时间戳dfn。
将所有要查询的点按dfn排序,用栈来维护根节点到当前点的链。
一开始,根节点入栈, s t [ + + t o p ] = 1 st[++top]=1 st[++top]=1
设当前加入的点为 x x x
- 用 w h i l e while while循环,若 d f n [ s [ t o p − 1 ] ] ≥ d f n [ l c a ( s [ t o p ] , x ) ] dfn[s[top-1]]\geq dfn[lca(s[top],x)] dfn[s[top−1]]≥dfn[lca(s[top],x)],那么 l c a lca lca为点 s t [ t o p ] st[top] st[top]的祖先,连边 ( s t [ t o p − 1 ] , s t [ t o p ] ) , t o p − − (st[top-1],st[top]),top-- (st[top−1],st[top]),top−−
- 若 d f n [ l c a ( s t [ t o p ] , x ) ] ≠ d f n [ s t [ t o p ] ] dfn[lca(st[top],x)]\neq dfn[st[top]] dfn[lca(st[top],x)]=dfn[st[top]],则 l c a lca lca在点 s t [ t o p ] st[top] st[top]和 s t [ t o p − 1 ] st[top-1] st[top−1]之间,连边 ( l c a , s t [ t o p ] ) , s t [ t o p ] = l c a (lca,st[top]),st[top]=lca (lca,st[top]),st[top]=lca,将 x x x入栈,然后退出
当所有点都考虑完了之后,还要对栈中的点依次连边并退栈。
code
void insert(int x){
if(top==1){
s[++top]=x;return;
}
int lca=LCA(x,s[top]);
// if(lca==s[top]) return;
while(top>1&&dfn[s[top-1]]>=dfn[lca]){
add(s[top-1],s[top]);--top;
}
if(lca!=s[top]){
add(lca,s[top]);s[top]=lca;
}
s[++top]=x;
}
其中被注释的一行是一般虚树加点操作没有的,但这道题需要。因为这道题如果一个点一定不与根节点相连,则其子树一定满足条件,所以子树可以不用考虑。而在最底部的 s [ t o p ] s[top] s[top]一定不是 L C A LCA LCA,所以遇到这种情况直接return即可。
加点过程如下
code
dfs(1,0);
while(m--){
scanf("%d",&k);
for(int i=1;i<=k;i++){
scanf("%d",&a[i]);
}
sort(a+1,a+k+1,cmp);
s[top=1]=1;
for(int i=1;i<=k;i++){
insert(a[i]);
}
while(top>1){
v[s[top-1]].push_back(s[top]);--top;
}
printf("%lld\n",dp(1));
}
SDOI2011消耗战
用虚树来做的话,时间复杂度为 O ( ∑ k log k ) O(\sum k\log k) O(∑klogk)。
code
#include<bits/stdc++.h>
using namespace std;
const int N=250000;
int n,m,k,tot=0,dt=0,top,d[500005],l[500005],r[500005];
int a[N+5],s[N+5],fa[N+5],tp[N+5],dep[N+5],siz[N+5],son[N+5],dfn[N+5];
long long w[500005],md[N+5];
vector<int>v[N+5];
bool cmp(int ax,int bx){
return dfn[ax]<dfn[bx];
}
void add(int xx,int yy,long long zz){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;w[tot]=zz;
}
void dfs1(int u,int f){
dep[u]=dep[f]+1;
fa[u]=f;siz[u]=1;
for(int i=r[u];i;i=l[i]){
if(d[i]==f) continue;
md[d[i]]=min(w[i],md[u]);
dfs1(d[i],u);
siz[u]+=siz[d[i]];
if(siz[d[i]]>siz[son[u]]) son[u]=d[i];
}
}
void dfs2(int u,int f){
dfn[u]=++dt;
if(son[u]){
tp[son[u]]=tp[u];
dfs2(son[u],u);
}
for(int i=r[u];i;i=l[i]){
if(d[i]==f||d[i]==son[u]) continue;
tp[d[i]]=d[i];
dfs2(d[i],u);
}
}
int gt(int x,int y){
while(tp[x]!=tp[y]){
if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
x=fa[tp[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
void insert(int x){
if(top==1){
s[++top]=x;return;
}
int lca=gt(x,s[top]);
if(lca==s[top]) return;
while(top>1&&dfn[s[top-1]]>=dfn[lca]){
v[s[top-1]].push_back(s[top]);--top;
}
if(s[top]!=lca){
v[lca].push_back(s[top]);s[top]=lca;
}
s[++top]=x;
}
long long dp(int u){
if(v[u].size()==0) return md[u];
long long sum=0;
for(int i=0;i<v[u].size();i++){
sum+=dp(v[u][i]);
}
v[u].clear();
return min(sum,md[u]);
}
int main()
{
int x,y;
long long z;
scanf("%d",&n);
md[1]=1e18;
for(int i=1;i<n;i++){
scanf("%d%d%lld",&x,&y,&z);
add(x,y,z);add(y,x,z);
}
dfs1(1,0);tp[1]=1;
dfs2(1,0);
scanf("%d",&m);
while(m--){
scanf("%d",&k);
for(int i=1;i<=k;i++){
scanf("%d",&a[i]);
}
sort(a+1,a+k+1,cmp);
s[top=1]=1;
for(int i=1;i<=k;i++){
insert(a[i]);
}
while(top>1){
v[s[top-1]].push_back(s[top]);--top;
}
printf("%lld\n",dp(1));
}
return 0;
}