题目描述
小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。
小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物
输入输出格式
输入格式:
第一行,两个整数N、M,其中M为宝物的变动次数。接下来的N-1行,每行三个整数x、y、z,表示村庄x、y之间有一条长度为z的道路。接下来的M行,每行一个整数t,表示一个宝物变动的操作。若该操作前村庄t内没有宝物,则操作后村庄内有宝物;若该操作前村庄t内有宝物,则操作后村庄内没有宝物。
输出格式:
M行,每行一个整数,其中第i行的整数表示第i次操作之后玩家找到所有宝物需要行走的最短路程。若只有一个村庄内有宝物,或者所有村庄内都没有宝物,则输出0。
输入输出样例
输入样例#1:
4 5
1 2 30
2 3 50
2 4 60
2
3
4
2
1
输出样例#1:
0
100
220
220
280
说明
1<=N<=100000
1<=M<=100000
对于全部的数据,1<=z<=10^9
分析:
显然最短路径可以通过按所有点的dfs序进行排序后,第
i
i
i个点与第
i
+
1
i+1
i+1个点的距离的和。当然还要加上最后一个点到起点的距离。
对于插入关键点操作,可以维护一个set来维护关键点,修改时减去前一个点到后一个的距离,加上前一个点到当前点的距离与当前点到后一个点的距离。
代码:
// luogu-judger-enable-o2
#include <iostream>
#include <cstdio>
#include <cmath>
#include <set>
#define LL long long
const int maxn=1e5+7;
using namespace std;
int n,m,x,y,w,cnt;
int ls[maxn],f[maxn][20],dep[maxn],dfn[maxn],id[maxn],vis[maxn];
LL dis[maxn],ans;
struct edge{
int y,w,next;
}g[maxn*2];
set <int> s;
void add(int x,int y,int w)
{
g[++cnt]=(edge){y,w,ls[x]};
ls[x]=cnt;
}
void dfs(int x,int fa)
{
dfn[x]=++cnt;
id[cnt]=x;
f[x][0]=fa;
dep[x]=dep[fa]+1;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if (y==fa) continue;
dis[y]=dis[x]+(LL)g[i].w;
dfs(y,x);
}
}
int getlca(int x,int y)
{
if (dep[x]>dep[y]) swap(x,y);
int d=dep[y]-dep[x],k=19,t=1<<k;
while (d)
{
if (d>=t) d-=t,y=f[y][k];
t/=2,k--;
}
if (x==y) return x;
k=19;
while (k>=0)
{
if (f[x][k]!=f[y][k])
{
x=f[x][k];
y=f[y][k];
}
k--;
}
return f[x][0];
}
LL getdis(int x,int y)
{
int d=getlca(x,y);
return dis[x]+dis[y]-2*dis[d];
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&w);
add(x,y,w);
add(y,x,w);
}
cnt=0;
dfs(1,0);
for (int j=1;j<20;j++)
{
for (int i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1];
}
ans=0;
for (int i=1;i<=m;i++)
{
scanf("%d",&x);
if (!vis[x]) s.insert(dfn[x]);
int last,next;
set <int> ::iterator it,it0;
it=s.lower_bound(dfn[x]);
if (it!=s.begin())
{
it--;
last=id[*it];
it++;
}
else
{
it0=s.end();
it0--;
last=id[*it0];
}
it++;
if (it!=s.end()) next=id[*it];
else
{
it0=s.begin();
next=id[*it0];
}
it--;
if (!vis[x])
{
ans-=getdis(last,next);
ans+=getdis(last,x);
ans+=getdis(x,next);
}
else
{
ans+=getdis(last,next);
ans-=getdis(last,x);
ans-=getdis(x,next);
s.erase(it);
}
vis[x]^=1;
printf("%lld\n",ans);
}
}