这个版本把 1 号节点强制加入虚树,作为虚树的根。
坑点:1. 每次询问后不能把 n 个点全部重新初始化,只能清理本次询问用到的 k 个点(dp 中顺手清空 he2,main 中按 ve 清空 vis),否则会超时(TLE)。
2. 虚树建边时建从父亲指向儿子的有向边比较好,这样 dp 遍历时不需要再判断父节点。
#include <bits/stdc++.h>
using namespace std;
// Integer types: every id, depth, weight and sum in this program is 64-bit.
using LL = long long;
using lint = LL;
// Array bounds: vertices are 1-based; each input edge is stored in both directions.
constexpr lint maxn = 300001;
constexpr lint maxm = 600001;
// Sentinel: lca2's starting minimum, and dp's result for a vertex without virtual children.
constexpr lint inf = 0x3f3f3f3f;
// Input tree as forward-star adjacency lists; edge id 0 terminates a list.
lint tot;          // id of the most recently appended edge
lint he[maxn];     // he[v]: first edge id in v's list (0 if empty)
lint ver[maxm];    // ver[e]: vertex the edge points to
lint ne[maxm];     // ne[e]: next edge id in the same list
lint cost[maxm];   // cost[e]: edge weight
// Prepends the directed edge x -> y with weight w to x's list.
void add( lint x,lint y,lint w ){
    const lint id = ++tot;
    ver[id] = y;
    cost[id] = w;
    ne[id] = he[x];
    he[x] = id;
}
// Virtual tree of the current query, same forward-star layout as the input tree.
lint tot2;          // id of the most recently appended virtual edge
lint he2[maxn];     // he2[v]: first virtual edge out of v (0 if none)
lint ver2[maxm];    // ver2[e]: child vertex of the virtual edge
lint ne2[maxm];     // ne2[e]: next virtual edge out of the same vertex
lint cost2[maxm];   // cost2[e]: weight attached to the virtual edge
// Prepends the directed virtual edge x -> y with weight w to x's list.
void add2( lint x,lint y,lint w ){
    const lint id = ++tot2;
    ver2[id] = y;
    cost2[id] = w;
    ne2[id] = he2[x];
    he2[x] = id;
}
lint vis[maxn];   // vis[v] != 0 marks v as a query vertex of the current query
// Clears the virtual-tree heads and the query marks for vertices 1..n and resets
// the virtual edge counter. NOTE: main does not call this; an O(n) reset per
// query is what the header note warns will time out.
void init(lint n){
    for( lint v = n; v >= 1; --v ){
        he2[v] = 0;
        vis[v] = 0;
    }
    tot2 = 1;
}
// Rooted-tree data filled by dfs(1, 0):
lint d[maxn];        // d[v]: depth of v (the root keeps depth 0)
lint f[maxn][21];    // f[v][i]: 2^i-th ancestor of v, 0 if it does not exist
lint dfn[maxn];      // dfn[v]: DFS entry order of v
lint cnt = 0;        // last DFS entry number handed out
lint f2[maxn][21];   // f2[v][i]: minimum edge weight over the 2^i edges above v
queue<lint> que;     // not used by the code shown here
// Depth-first traversal from x (entered from fa). It assigns DFS order and depths,
// and fills the binary-lifting ancestor/min-weight tables of every child before
// descending into it.
void dfs( lint x ,lint fa ){
    dfn[x] = ++cnt;
    for( lint edge = he[x]; edge != 0; edge = ne[edge] ){
        const lint child = ver[edge];
        if( child == fa ) continue;
        d[child] = d[x] + 1;
        f[child][0] = x;
        f2[child][0] = cost[edge];
        for( lint k = 1; (1 << k) <= d[child]; ++k ){
            const lint mid = f[child][k-1];
            f[child][k] = f[mid][k-1];
            f2[child][k] = min( f2[child][k-1], f2[mid][k-1] );
        }
        dfs( child, x );
    }
}
// Lowest common ancestor of x and y using the binary-lifting table f.
lint lca( lint x,lint y ){
    if( d[x] < d[y] ) swap( x, y );
    // Lift the deeper vertex x up to y's depth.
    for( lint k = 20; k >= 0; --k ){
        const lint up = f[x][k];
        if( up != 0 && d[up] >= d[y] ) x = up;
    }
    if( x == y ) return x;
    // Lift both while they stay below the LCA.
    for( lint k = 20; k >= 0; --k ){
        if( f[x][k] != f[y][k] ){
            x = f[x][k];
            y = f[y][k];
        }
    }
    return f[x][0];
}
// Minimum edge weight on the tree path between x and y, where the shallower of
// the two is expected to be an ancestor of the deeper one (solve only calls it
// that way). Returns inf when the path is empty (x == y).
lint lca2( lint x,lint y ){
    if( d[x] < d[y] ) swap( x, y );
    lint best = inf;
    for( lint k = 20; k >= 0; --k ){
        const lint up = f[x][k];
        if( up == 0 || d[up] < d[y] ) continue;
        best = min( best, f2[x][k] );
        x = up;
    }
    return best;
}
// Orders vertices by DFS entry time, the order the stack-based virtual-tree build needs.
bool cmp( lint x,lint y ){
    const lint lhs = dfn[x];
    const lint rhs = dfn[y];
    return lhs < rhs;
}
vector<lint> ve;   // query vertices of the current query
lint st[maxn];     // stack holding the current root-to-leaf chain; valid entries st[1..top]
lint top;          // number of entries on st
// Tree DP over the virtual tree below x (fa is unused; virtual edges are directed
// parent -> child). Returns the minimum total weight of cut edges so that no
// marked vertex (vis[] != 0) in x's virtual subtree stays connected to x:
// a marked child must have its connecting path cut (cost2), an unmarked child
// may instead have everything below it cut. A vertex with no virtual children
// returns inf. Clears he2[x] so the next query starts from empty lists without
// an O(n) reset.
LL dp( lint x,lint fa ){
    if( he2[x] == 0 ) return inf;
    LL total = 0;
    for( lint edge = he2[x]; edge != 0; edge = ne2[edge] ){
        const lint child = ver2[edge];
        const LL below = dp( child, x );
        const LL link = cost2[edge];
        total += vis[child] ? link : min( link, below );
    }
    he2[x] = 0;
    return total;
}
// Builds the virtual tree of the current query vertices (ve) rooted at vertex 1,
// then returns dp(1, 0): the minimum total weight of edges to cut so that no
// query vertex stays connected to vertex 1. Virtual edges are directed
// parent -> child and weighted with the minimum tree-edge weight on the path
// they compress (lca2).
// Side effects: sorts ve by dfn and removes duplicates from it; fills the he2
// lists (dp clears them again).
lint solve(){
    top = 0;
    sort( ve.begin(), ve.end(), cmp );
    // Equal vertices are adjacent after sorting by dfn. A repeated vertex would be
    // linked to itself in the virtual tree and make dp recurse forever.
    ve.erase( unique( ve.begin(), ve.end() ), ve.end() );

    // Vertex 1 is always the stack bottom (the virtual root), even when it is also
    // a query vertex; leaving it off would disconnect dp(1, 0) from the edges built
    // below and leave their he2 heads uncleared for the next query.
    st[++top] = 1;
    bool has_target = false;
    for( const lint v : ve ){
        if( v == 1 ) continue;
        has_target = true;
        const lint anc = lca( st[top], v );
        // Pop chain vertices deeper than anc, linking each to its virtual parent.
        // st[1] is vertex 1 (depth 0), so it is never popped.
        while( top > 1 && d[anc] < d[st[top]] ){
            if( d[st[top-1]] < d[anc] ){
                // anc lies strictly between st[top-1] and st[top]: it replaces st[top].
                add2( anc, st[top], lca2( st[top], anc ) );
                st[top] = anc;
                break;
            }
            add2( st[top-1], st[top], lca2( st[top], st[top-1] ) );
            top--;
        }
        st[++top] = v;
    }
    // Link whatever chain is left on the stack.
    while( top > 1 ){
        add2( st[top-1], st[top], lca2( st[top], st[top-1] ) );
        top--;
    }
    // No vertex other than 1 to separate: nothing has to be cut
    // (dp would report inf for a childless root).
    if( !has_target ) return 0;
    return dp( 1, 0 );
}
int main(){
lint n;
tot=1;
scanf("%lld",&n);
for( lint i = 1;i <= n-1;i++ ){
lint x,y,w;
scanf("%lld%lld%lld",&x,&y,&w);
add(x,y,w);add(y,x,w);
}
dfs(1,0);
lint m;
scanf("%lld",&m);
for( lint k,d,i = 1;i <=m;i++ ){
scanf("%lld",&k);
ve.clear();tot2 = 1;
for( lint j = 1;j<=k;j++ ){
scanf("%lld",&d);vis[d]=1;
ve.push_back(d);
}
LL ans= solve();
for( lint i = 0;i < ve.size();i++ ) vis[ve[i]]=0;
printf("%lld\n",ans);
}
return 0;
}
下面是一位大佬写的版本:虚树中不强制加入 1 号节点,而是以所有关键点的 LCA 为根,并用 mn[x](1 到 x 路径上的最小边权)作为切断 x 的代价。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<cstring>
#include<algorithm>
// Loop helpers: F counts i up from j to n, D counts i down from j to n (inclusive).
#define F(i,j,n) for(int i=j;i<=n;i++)
#define D(i,j,n) for(int i=j;i>=n;i--)
using namespace std;
typedef long long ll;
const int maxn = 250005;                // vertex array bound
const ll inf = 1000000000000000000ll;   // mn[] of the root: no edge above it
int n, m, k;        // vertex count, query count, key vertices in the current query
int cnt;            // last used edge slot (shared by e[] and e2[])
int tot;            // last DFS entry time handed out
int top;            // stack height in solve()
int head[maxn];     // tree adjacency heads
int head2[maxn];    // virtual-tree adjacency heads (cleared by dp)
int d[maxn];        // depth, root = 0
int pos[maxn];      // DFS entry time
int p[maxn][25];    // p[v][i]: 2^i-th ancestor of v (0 if none)
int a[maxn];        // key vertices of the current query, 1-based
int s[maxn];        // stack for the virtual-tree build, 1-based
bool tag[maxn];     // tag[v]: v is a key vertex of the current query
ll f[maxn];         // dp value of v in the virtual tree
ll mn[maxn];        // mn[v]: minimum edge weight on the tree path from 1 to v
struct edge_type{int next,to,v;}e[maxn*2];   // tree edges, two arcs per input edge
struct edge_type2{int next,to;}e2[maxn];     // virtual-tree edges, parent -> child
// Reads the next integer from stdin. It skips non-digit characters (any '-' among
// them makes the result negative), then accumulates decimal digits.
// Returns 0 if the input ends before a digit is found. The original looped forever
// on EOF, because EOF is below '0'.
inline int read()
{
    int x = 0, sign = 1;
    int ch = getchar();   // int, not char: keeps EOF distinct from every real byte
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}
// Stores the undirected tree edge x - y of weight z as two directed arcs.
inline void add_edge(int x,int y,int z)
{
    ++cnt;
    e[cnt].next = head[x];
    e[cnt].to = y;
    e[cnt].v = z;
    head[x] = cnt;

    ++cnt;
    e[cnt].next = head[y];
    e[cnt].to = x;
    e[cnt].v = z;
    head[y] = cnt;
}
// Appends the directed virtual-tree edge x -> y (parent to child).
inline void add_edge2(int x,int y)
{
    ++cnt;
    e2[cnt].next = head2[x];
    e2[cnt].to = y;
    head2[x] = cnt;
}
// Orders vertices by DFS entry time pos[], as the stack-based virtual-tree build requires.
inline bool cmp(int x,int y)
{
    const int px = pos[x];
    const int py = pos[y];
    return px < py;
}
inline void dfs(int x,int fa)
{
pos[x]=++tot;
for(int i=1;(1<<i)<=d[x];i++) p[x][i]=p[p[x][i-1]][i-1];
for(int i=head[x];i;i=e[i].next) if (e[i].to!=fa)
{
mn[e[i].to]=min(mn[x],(ll)e[i].v);
d[e[i].to]=d[x]+1;
p[e[i].to][0]=x;
dfs(e[i].to,x);
}
}
inline int lca(int x,int y)
{
if (d[x]<d[y]) swap(x,y);
int t=d[x]-d[y];
for(int i=0;(1<<i)<=t;i++) if (t&(1<<i)) x=p[x][i];
if (x==y) return x;
D(i,20,0) if (p[x][i]!=p[y][i]) x=p[x][i],y=p[y][i];
return p[x][0];
}
inline void dp(int x)
{
f[x]=mn[x];
ll tmp=0;
for(int i=head2[x];i;i=e2[i].next) dp(e2[i].to),tmp+=f[e2[i].to];
if (tmp&&!tag[x]) f[x]=min(f[x],tmp);
head2[x]=0;
}
inline void solve()
{
cnt=0;
k=read();
F(i,1,k) a[i]=read(),tag[a[i]]=true;
sort(a+1,a+k+1,cmp);
top=0;
F(i,1,k)
{
if (top==0){s[++top]=a[i];continue;}
int lc=lca(s[top],a[i]);
while (d[lc]<d[s[top]])
{
if (d[lc]>=d[s[top-1]])
{
add_edge2(lc,s[top]);
if (s[--top]!=lc) s[++top]=lc;
break;
}
add_edge2(s[top-1],s[top]);top--;
}
s[++top]=a[i];
}
while (top>1) add_edge2(s[top-1],s[top]),top--;
dp(s[1]);printf("%lld\n",f[s[1]]);
for (int i=1;i<=k;i++) tag[a[i]]=false;
}
int main()
{
n=read();
F(i,1,n-1)
{
int x=read(),y=read(),z=read();
add_edge(x,y,z);
}
mn[1]=inf;dfs(1,0);
m=read();
while (m--) solve();
return 0;
}