思路来源
https://www.cnblogs.com/stoorz/p/12388534.html 板子
https://www.cnblogs.com/zwfymqz/p/9175152.html 构建过程、复杂度证明
知识点整理
以下部分来自https://www.cnblogs.com/zwfymqz/p/9175152.html
维护一个栈,建虚树的时候分三种情况,一边dfs一边建即可
一般考察虚树dp,复杂度是O(2*sumk),因为加入一个点最多多一个lca
板子整理
把0号点当做一个默认出现的虚点,这样统一了很多情况
key[i]=1的点是虚树标记的点,par关系是新建的虚树
#include<bits/stdc++.h>
using namespace std;
const int N=200010,LG=20;
int n,m;
vector<int>e[N]
int stk[N],top;//虚树栈
int dep[N],f[N][LG+1];//lca
int par[N];//虚树
bool key[N];//关键点
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=LG;i>=0;i--){
if(dep[f[x][i]]>=dep[y]){
x=f[x][i];
}
}
if(x==y)return x;
for(int i=LG;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i],y=f[y][i];
}
}
return f[x][0];
}
void dfs(int u,int fa){
dep[u]=dep[fa]+1;
f[u][0]=fa;
for(int i=1;i<=LG;i++){
f[u][i]=f[f[u][i-1]][i-1];
}
if(key[u]){
int p=lca(u,stk[top]);
if(p!=stk[top]){
while(dep[stk[top-1]]>dep[p]){
par[stk[top]]=stk[top-1];
top--;
}
par[stk[top]]=p;
top--;
if(stk[top]!=p){
stk[++top]=p;
}
}
stk[++top]=u;
}
for(int v:e[u]){
if(v==fa)continue;
dfs(v,u);
}
}
void build(){//以0为虚根
top=1;
dfs(1,0);
for(int i=top;i>=1;i--){
par[stk[i]]=stk[i-1];
}
}
int main(){
return 0;
}
支持多组询问的虚树(2024.2.11补充)
先dfs一遍,得到dfs序,
对于单组询问的询问点,先按dfn增序排序,
然后套用建一次虚树的方式,此时虚根仍然是0
以abc340g为例,也可参考洛谷p2495【消耗战】
KAJIMA CORPORATION CONTEST 2024(AtCoder Beginner Contest 340) G. Leaf Color(虚树+dp)-优快云博客
#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<vector>
#include<map>
#include<queue>
using namespace std;
typedef array<int,2> a2;
typedef array<int,3> a3;
const int N=2e5+10,LG=20,mod=998244353;
int n,u,v,ans,a[N];
vector<int>col[N],e[N],g[N];
int dep[N],f[N][LG+1],dfn[N],c;
bool key[N];//关键点
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=LG;i>=0;i--){
if(dep[f[x][i]]>=dep[y]){
x=f[x][i];
}
}
if(x==y)return x;
for(int i=LG;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i],y=f[y][i];
}
}
//printf("x:%d y:%d lca:%d\n",x,y,f[x][0]);
return f[x][0];
}
void dfs(int u,int fa){
dep[u]=dep[fa]+1;
dfn[u]=++c;
//printf("u:%d dfn:%d\n",u,dfn[u]);
f[u][0]=fa;
for(int i=1;i<=LG;i++){
f[u][i]=f[f[u][i-1]][i-1];
}
for(auto &v:e[u]){
if(v==fa)continue;
dfs(v,u);
}
}
void add(int x,int y){
//if(x==y)return;
//printf("x:%d y:%d\n",x,y);
g[x].push_back(y);
}
a2 dfs1(int u){
//printf("u:%d\n",u);
a2 dp={1,1};
a3 h={1,0,0};
for(auto &v:g[u]){
auto dp2=dfs1(v);
dp[0]=1ll*dp[0]*dp2[0]%mod;
dp[1]=1ll*dp[1]*(dp2[1]+dp2[0])%mod;
h[2]=(1ll*h[2]*(dp2[0]+dp2[1])%mod+1ll*h[1]*dp2[1]%mod)%mod;
h[1]=(1ll*h[1]*dp2[0]%mod+1ll*h[0]*dp2[1]%mod)%mod;
h[0]=1ll*h[0]*dp2[0]%mod;
}
if(key[u]){
ans=(ans+dp[1])%mod;
}
else{
ans=(ans+h[2])%mod;
dp[1]=(dp[1]+mod-1)%mod;//非关键点的不能当叶子
}
key[u]=0;
g[u].clear();
return dp;
}
void build(vector<int>&a){//以0为虚根
static int stk[N],top;//虚树栈
sort(a.begin(),a.end(),[&](int x,int y){
return dfn[x]<dfn[y];
});
int sz=a.size();
stk[top=1]=a[0];
key[a[0]]=1;// 标记为关键点
for(int i=1;i<sz;++i){
int u=a[i];
key[u]=1;// 标记为关键点
int p=lca(u,stk[top]);
if(p!=stk[top]){
while(dep[stk[top-1]]>dep[p]){
add(stk[top-1],stk[top]);
top--;
}
add(p,stk[top]);
top--;
if(stk[top]!=p){
stk[++top]=p;
}
}
stk[++top]=u;
}
for(int i=top;i>=2;i--){
add(stk[i-1],stk[i]);
}
dfs1(stk[1]);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
col[a[i]].push_back(i);
}
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1,0);
for(int i=1;i<=n;++i){
if(!col[i].size())continue;
build(col[i]);
//puts("");
}
printf("%d\n",ans);
return 0;
}