前言
在做树上动态规划时,可能有很多节点是不必要的。我们通过建立虚树,可以去除不必要的节点,优化时间复杂度。
虚树优化树形动态规划
通常的写法喜欢用栈建立虚树,但是不太好写,这里说一种快速建立虚树的方法:
把点按dfn排序,相邻点的lca丢进去,去重之后再dfn排序,每个点在虚树上的父亲就是自己和前驱的lca。
证明很显然。
然后就有了虚树优化树形动态规划
代码如下:
#include<iostream>
#include<vector>
#include<algorithm>
#include<cstring>
#include<climits>
using namespace std;
int n;
int dfn[250005],cnt;
vector<vector<pair<int,int>>> a;
vector<pair<int,int>> b[250005];
bool cmp(int a,int b) {
return dfn[a]<=dfn[b];
}
int f[20][250005],g[20][250005];
//g[k][i]表示从i节点向上走2^k条边,经过的最小边权
int *fa=f[0],size[250005],son[250005],top[250005],deep[250005];
int dfs1(int u) {
size[u]=1;
deep[u]=deep[fa[u]]+1;
for(auto&i:a[u]) {
int v=i.first,p=i.second;
if(v==fa[u]) continue;
fa[v]=u;
g[0][v]=p;
size[u]+=dfs1(v);
if(size[son[u]]<size[v]) son[u]=v;
}
return size[u];
}
void dfs2(int u) {
if(son[fa[u]]==u) top[u]=top[fa[u]];
else top[u]=u;
dfn[u]=++cnt;
if(son[u]) dfs2(son[u]);
for(auto&i:a[u]) {
int v=i.first;
if(v==fa[u]||v==son[u]) continue;
dfs2(v);
}
}
int LCA(int x,int y) {
if(deep[x]<deep[y]) swap(x,y);
for(int k=19;k>=0;k--)
if(deep[f[k][x]]>=deep[y])
x=f[k][x];
if(x==y) return x;
for(int k=19;k>=0;k--)
if(f[k][x]^f[k][y])
x=f[k][x],
y=f[k][y];
return fa[x];
}
int find(int x,int y) {
//deep[x]>deep[y]
int ans=INT_MAX;
for(int k=19;k>=0;k--)
if(deep[f[k][x]]>=deep[y])
ans=min(ans,g[k][x]),
x=f[k][x];
return ans;
}
long long h[250005];
bool w[250005];
long long dfs3(int u) {
h[u]=0;
if(w[u]) h[u]=1e14,w[u]=0;
for(auto&i:b[u]) {
int v=i.first,p=i.second;
h[u]+=min(dfs3(v),(long long)p);
}
return h[u];
}
int lca[250005];
int x[500005],num;
int main() {
for(auto&i:g)
for(auto&j:i)
j=INT_MAX;
cin>>n;
a.resize(n+1);
for(int i=1;i<n;i++) {
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
a[u].push_back({v,w});
a[v].push_back({u,w});
}
dfs1(1);
dfs2(1);
for(int k=1;k<20;k++)
for(int i=1;i<=n;i++)
f[k][i]=f[k-1][f[k-1][i]],
g[k][i]=min(g[k-1][i],g[k-1][f[k-1][i]]);
int m;
cin>>m;
while(m--) {
num=0;
x[num++]=1;
int K;
scanf("%d",&K);
for(int i=1;i<=K;i++) {
int y;
scanf("%d",&y);
w[y]=1;
x[num++]=y;
}
sort(x,x+num,cmp);
// cout<<"***";
// for(auto&i:x) cout<<i<<' ';
// cout<<"***"<<endl;
for(int i=1;i<=K;i++) {
int L=LCA(x[i],x[i-1]);
if(L^x[i]&&L^x[i-1])
x[num++]=L;
}
stable_sort(x,x+num,cmp);
num=unique(x,x+num)-x;
//按dfs序排序
// cout<<"***";
// for(auto&i:x) cout<<i<<' ';
// cout<<"***"<<endl;
for(int i=1;i<num;i++) {
lca[i]=LCA(x[i],x[i-1]);
int p=find(x[i],lca[i]);
b[lca[i]].push_back({x[i],p});
}
printf("%lld\n",dfs3(1));
for(int i=1;i<num;i++)
if(b[lca[i]].size())
b[lca[i]].clear();
}
return 0;
}
- 注意到x的空间最多开到2n
- 建虚树的时候其实并没有“去重之后再排序”,而是先根据dfn排序,然后再去重。注意这里用的是归并排序(stable_sort),因为如果询问是一条链的话,数组会趋近于有序,sort就被卡成O(n2)了,就会TLE。当然实测sort也不是卡过不去,具体的写法是把两句去重换成三句:
sort(x,x+num);//按普通顺序排序,去重
num=unique(x,x+num)-x;
sort(x,x+num,cmp);//按dfs序排序
不过这种写法其实能过比较巧合,因为按照普通顺序sort没有被卡,如果被卡了,这种方法也是有可能过不去的。
- 注意清空b数组的时候不能直接枚举全部清空,这样复杂度不对,应该只清空使用过的地方,然后注意不要使用b[lca[i]].resize(0)来清空,因为这个函数比clear()慢很多。(也能过去就是了)
后记
于是皆大欢喜。