题意:给定一棵树,查询时给定两个点,求出两个点的距离。
暴力做肯定超时的。我的做法是采用lca(最近公共祖先)的离线算法,即tarjan算法(据说Tarjan提出了很多算法,可能还有很多tarjan算法),算法里用到了并查集。在输入完所有查询之后,在求出答案。tarjan算法的做法是:一开始vis数组初始化为0,从树根开始递归往下对点进行染色,刚到一个点的时候将vis取为-1,在继续递归;遍历完子节点返回之后vis变为1。在vis变为1之前,检索一下当前节点的所有查询,设查询中的另外一个节点为To,如果vis[To]==0,就continue,因为To还没有处理,不知道它的信息;如果vis[To]==-1,说明To被访问了一次,但是还没有返回到,这意味着To是当前节点的祖先,因此To就是当前节点的最近公共祖先;如果vis[To]==1,说明To已经处理完了,这时候并查集就派上用场了。在递归时,当一个节点处理完返回到父亲那里时,就把父亲变成其所在集合的代表元素。在刚才讨论到vis[To]==1的情况中,可以知道find(To)(即To所在集合的代表元素)就是To和当前节点的最近公共祖先了(这个可以画图演算一下)。在这道题中,我们一开始可以用一个简单的递归算出每个点到根节点的距离dis[i]。那么对于一个查询的两个点fir和sec,它们的距离就是dis[fir]-dis[lca]+dis[sec]-dis[lca],lca是fir和sec的最近公共祖先。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<set>
#include<climits>
#include<queue>
#include<vector>
#include<map>
using namespace std;
struct node
{
int to,id;
node(int t,int i)
{
to=t;
id=i;
}
node(){}
};
const int maxn=50005;
vector<node>vec[maxn];
vector<pair<int,int>>query;
int father[maxn],fir[maxn<<1],nxt[maxn<<1],vv[maxn<<1],val[maxn<<1],dis[maxn],ans[75005],e;
int vis[maxn];//0 means it's white,-1 means it's grey, 1 means it's black
int findn(int n)
{
if(n!=father[n]) father[n]=findn(father[n]);
return father[n];
}
void add(int a,int b,int c,int i)
{
vv[e]=b;
val[e]=c;
nxt[e]=fir[a];
fir[a]=e++;
}
void get_height(int sroot,int dist)
{
vis[sroot]=1;
dis[sroot]=dist;
for(int i=fir[sroot];i!=-1;i=nxt[i])
{
int v=vv[i];
if(!vis[v])
{
get_height(v,dist+val[i]);
}
}
}
void dfs(int cur,int fa)
{
vis[cur]=-1;
for(int i=fir[cur];i!=-1;i=nxt[i])
{
int v=vv[i];
if(!vis[v])
{
dfs(v,cur);
father[v]=cur;
}
}
int size=vec[cur].size();
for(int i=0;i<size;i++)
{
node nxt=vec[cur][i];
if(!vis[nxt.to]) continue;
if(-1==vis[nxt.to])
{
ans[nxt.id]=nxt.to;
}
else if(1==vis[nxt.to])
{
ans[nxt.id]=findn(nxt.to);
}
}
vis[cur]=1;
}
int main()
{
#pragma comment(linker, "/STACK:102400000,102400000")//此代码需要扩栈,可能在递归时耗的内存有点大
int n;
while(scanf("%d",&n)!=EOF)
{
for(int i=0;i<=n;i++)
{
father[i]=i;
fir[i]=-1;
vis[i]=0;
vec[i].clear();
}
e=0;//important
int a,b,c;
for(int i=0;i<n-1;i++)
{
scanf("%d%d%d",&a,&b,&c);
add(a,b,c,i);
add(b,a,c,i);
}
get_height(0,0);
int q;
scanf("%d",&q);
for(int i=0;i<q;i++)
{
scanf("%d%d",&a,&b);
vec[a].push_back(node(b,i));
vec[b].push_back(node(a,i));
query.push_back(make_pair<int,int>(a,b));
}
for(int i=0;i<=n;i++) vis[i]=0;
dfs(0,0);
int size=query.size();
for(int i=0;i<size;i++)
{
int fir=query[i].first;
int sec=query[i].second;
int lca=ans[i];
int distance=abs(dis[lca]-dis[fir])+abs(dis[lca]-dis[sec]);
printf("%d\n",distance);
}
}
}