题目大意:
给出一颗树,有些点被染色过,每条边有一个权值,求这棵树上不经过k个染色点的最长路径(起点终点可以相同)
分析:
非常裸的点分治,套一个树状数组存储经过k个染色点的最短路径,为了避免卡常,我用了一个更新标记数组,不难实现。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define SF scanf
#define PF printf
#define MAXN 200010
using namespace std;
int tree[MAXN],n,m,k,cnt,ans,tot;
int col[MAXN],foc[MAXN],sumx[MAXN],sumy[MAXN],up[MAXN];
bool dele,del[MAXN],vis[MAXN];
vector<int> a[MAXN],p[MAXN];
void dp(int x,int fa){
vis[x]=1;
for(int i=0;i<a[x].size();i++)
if(a[x][i]!=fa&&del[a[x][i]]==0){
dp(a[x][i],x);
sumx[x]+=sumx[a[x][i]]+1;
}
}
int dp1(int x,int fa,int sum){
sumy[x]=sum-sumx[x];
for(int i=0;i<a[x].size();i++)
if(a[x][i]!=fa&&del[a[x][i]]==0)
sumy[x]=max(sumy[x],sumx[a[x][i]]+1);
int x1=x;
for(int i=0;i<a[x].size();i++)
if(a[x][i]!=fa&&del[a[x][i]]==0){
int x2=dp1(a[x][i],x,sum);
if(sumy[x2]<sumy[x1])
x1=x2;
}
return x1;
}
void find_foc(){
memset(foc,0,sizeof foc);
memset(sumx,0,sizeof sumx);
memset(sumy,0,sizeof sumy);
memset(vis,0,sizeof vis);
dele=0;
cnt=0;
for(int i=1;i<=n;i++){
if(del[i]==1)
continue;
if(vis[i]==0){
dele=1;
dp(i,0);
foc[++cnt]=dp1(i,0,sumx[i]);
}
}
}
int query(int x){
int res=0;
x++;
while(x){
if(tot==up[x])
res=max(tree[x],res);
x-=x&(-x);
}
return res;
}
void add(int x,int y){
x++;
while(x<=k+1){
if(tot!=up[x]){
tree[x]=y;
up[x]=tot;
}
else
tree[x]=max(tree[x],y);
x+=x&(-x);
}
}
void dfs(int x,int fa,int sum,int tot){
if(tot<=k){
int ans1=query(k-tot)+sum;
ans=max(ans,ans1);
}
for(int i=0;i<a[x].size();i++)
if(a[x][i]!=fa&&del[a[x][i]]==0){
if(col[a[x][i]]==1)
dfs(a[x][i],x,sum+p[x][i],tot+1);
else
dfs(a[x][i],x,sum+p[x][i],tot);
}
}
void update(int x,int fa,int sum,int tot){
if(tot<=k)
add(tot,sum);
for(int i=0;i<a[x].size();i++)
if(a[x][i]!=fa&&del[a[x][i]]==0){
if(col[a[x][i]]==1)
update(a[x][i],x,sum+p[x][i],tot+1);
else
update(a[x][i],x,sum+p[x][i],tot);
}
}
int u,v,val;
int main(){
SF("%d%d%d",&n,&k,&m);
for(int i=1;i<=m;i++){
SF("%d",&u);
col[u]=1;
}
for(int i=1;i<n;i++){
SF("%d %d %d",&u,&v,&val);
a[u].push_back(v);
a[v].push_back(u);
p[u].push_back(val);
p[v].push_back(val);
}
find_foc();
while(dele==1){
for(int i=1;i<=cnt;i++){
int x=foc[i];
tot++;
for(int i=0;i<a[x].size();i++)
if(del[a[x][i]]==0){
dfs(a[x][i],x,p[x][i],col[x]+col[a[x][i]]);
update(a[x][i],x,p[x][i],col[a[x][i]]);
}
del[x]=1;
}
find_foc();
}
PF("%d",ans);
}