该题是一道比较基础的 LCA（最近公共祖先）题：用 Tarjan 离线算法快速求出树上任意两点的最近公共祖先，同时在 DFS 中维护每个结点到根结点的距离 dist[x]（即根到 x 路径上的边权之和），于是任意两点的距离为 dist[u] + dist[v] - 2*dist[lca(u, v)]。
细节参见代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<string>
#include<vector>
#include<stack>
using namespace std;
// 64-bit accumulator type for path lengths.
using ll = long long;

// Tolerances and sentinels (kept for completeness).
constexpr double eps = 1e-6;
constexpr int INF = 1000000000;

// Array capacities: node count and query count, each with a little slack.
constexpr int maxn = 40000 + 5;
constexpr int maxq = 200 + 5;

// Input scratch: T test cases; n nodes, m queries; (u, v, k) = current edge.
int T, n, u, v, k, m;

// dist[x]: weighted distance from the DFS root to node x.
ll dist[maxn];

int f[maxn];       // union-find parent, -1 marks a set representative
int answer[maxq];  // answer[i]: LCA node of the i-th query
int h[maxn];       // head index of each node's query list, -1 when empty
int tt;            // number of query-list slots in use
int q;             // not referenced elsewhere in the visible code
int root;          // node the DFS starts from
// Union-find lookup. Returns the representative of x's set (the node whose
// f[] entry is -1) and re-points every node met on the way directly at it
// (path compression), without recursion.
int _find(int x) {
    int rep = x;
    while (f[rep] != -1) {
        rep = f[rep];
    }
    while (x != rep) {
        const int up = f[x];
        f[x] = rep;
        x = up;
    }
    return rep;
}
// Union: merges the sets of u and v by hanging u's representative under v's.
// Does nothing when both already share a representative.
void bing(int u, int v) {
    const int repU = _find(u);
    const int repV = _find(v);
    if (repU == repV) {
        return;
    }
    f[repU] = repV;
}
// Per-test DFS state for the offline LCA pass.
bool vis[maxn];      // set once LCA() has entered the node
bool flag[maxn];     // set when the node is the second endpoint of an input edge
int ancestor[maxn];  // ancestor[rep]: tree node standing for union-find set rep
// Forward-star adjacency list for the tree. Every input edge is stored twice
// (u->v and v->u). head[x] is the index of the most recently added edge
// leaving x, or -1 when there is none, and tot counts the slots used in edge[].
struct Edge {
    // to/next are indices; int matches how LCA() reads them (int v, int i)
    // and avoids the implicit long long -> int narrowing of the old layout.
    int to;    // target node of this directed edge
    int next;  // next edge leaving the same source node, or -1 at the end
    ll dist;   // edge weight; ll because it is summed into dist[]
} edge[maxn*2];
int head[maxn], tot;
void addedge(int u, int v, int dist) {
edge[tot].to = v;
edge[tot].dist = dist;
edge[tot].next = head[u];
head[u] = tot++;
}
// Offline query list in forward-star form. Each query (u, v) is stored once on
// u's list and once on v's list, and h[x] heads node x's list.
struct Query {
    int q;      // the other endpoint of the query
    int next;   // next entry on the same node's list, or -1 at the end
    int index;  // input position of the query, i.e. its slot in answer[]
} query[maxq*2];
void add_query(int u, int v, int index) {
query[tt].q = v;
query[tt].next = h[u];
query[tt].index = index;
h[u] = tt++;
query[tt].q = u;
query[tt].next = h[v];
query[tt].index = index;
h[v] = tt++;
}
// Resets all per-test-case state before a new tree is read.
void init() {
    // Edge and query lists: empty (-1 heads) with no slots in use.
    tot = tt = 0;
    memset(head, -1, sizeof(head));
    memset(h, -1, sizeof(h));

    // Union-find: every node is the representative of its own set.
    memset(f, -1, sizeof(f));

    // DFS bookkeeping and root detection.
    memset(vis, 0, sizeof(vis));
    memset(flag, 0, sizeof(flag));
    memset(ancestor, 0, sizeof(ancestor));
    memset(dist, 0, sizeof(dist));
}
// Tarjan's offline LCA: depth-first walk from u that
//  (a) fills dist[] for every descendant (parent distance + edge weight),
//  (b) merges each finished child subtree into u's union-find set, and
//  (c) answers every stored query (u, v) whose other endpoint v has already
//      been entered by this walk.
// Recursion depth equals the depth of the tree, up to n on a path-shaped tree.
// NOTE(review): confirm the judge's stack is large enough for n = 40000.
void LCA(int u) {
// Until children are merged in, u alone represents its own set.
ancestor[u] = u;
vis[u] = true;
for(int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
// Edges are stored in both directions. In a tree, the only neighbour already
// visited at this point is the DFS parent, so this keeps the walk downward.
if(vis[v]) continue;
dist[v] = dist[u] + edge[i].dist;
LCA(v);
// v's subtree is finished: merge it into u's set and label the merged set
// with u (bing's union direction does not matter because of this relabel).
bing(u, v);
ancestor[_find(u)] = u;
}
// Answer queries that touch u. If the other endpoint v was already entered,
// ancestor[] of v's set is the LCA of u and v. Each query is stored on both
// endpoints (see add_query), so at least one of them sees the other visited.
for(int i = h[u]; i != -1; i = query[i].next) {
int v = query[i].q;
if(vis[v]) {
answer[query[i].index] = ancestor[_find(v)];
}
}
}
// Raw query pairs, kept so results can be printed in input order.
struct node {
    int u;
    int v;
} a[maxq];
// Reads T test cases. Each case has a weighted tree with n nodes, followed by
// m node pairs; for every pair it prints the length of the tree path between them.
int main() {
    if (scanf("%d", &T) != 1) {
        return 0;  // no readable test count: nothing to do
    }
    while (T--) {
        scanf("%d%d", &n, &m);
        init();

        // n-1 weighted edges, stored in both directions.
        for (int i = 0; i < n - 1; i++) {
            scanf("%d%d%d", &u, &v, &k);
            flag[v] = true;  // v is the second endpoint of some edge
            addedge(u, v, k);
            addedge(v, u, k);
        }

        // Record all queries up front: Tarjan's LCA answers them in one DFS.
        // NOTE(review): a[] and query[] hold at most maxq queries; m is assumed
        // to stay within that bound.
        for (int i = 0; i < m; i++) {
            scanf("%d%d", &a[i].u, &a[i].v);
            add_query(a[i].u, a[i].v, i);
        }

        // Root the DFS at the first node never seen as a second endpoint.
        // With n-1 edges over nodes 1..n, at least one node stays unflagged.
        for (int i = 1; i <= n; i++) {
            if (!flag[i]) {
                root = i;
                break;
            }
        }
        LCA(root);

        for (int i = 0; i < m; i++) {
            const int lca = answer[i];
            // dist(u, v) = dist[u] + dist[v] - 2 * dist[lca(u, v)]
            const long long d = dist[a[i].u] + dist[a[i].v] - 2 * dist[lca];
            // "%I64d" is an MSVC-only length modifier (glibc reads it as the
            // 'I' flag plus width 64 with an int argument); iostream prints a
            // long long portably on every toolchain.
            std::cout << d << '\n';
        }
    }
    return 0;
}