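/*
 * 思路很简单:假设是一棵从 1 号节点出发(以 1 为根)的树。只要能封锁对方往外走的最长距离,就能全收剩余的点。
 * 因此,求出两个起始点间的距离,各自向对方移动,相遇后依次选择最优分支。
 * 关键在于数据量太大,相遇后的选择要预处理,否则存在 O(nm) 的做法,肯定超时。
 * 借鉴了别人的思路:按奇偶性预处理(前缀和),然后二分。
 *
 * Idea (English): treat the input as a tree rooted at vertex 1. A player who can
 * block the opponent's longest way outward collects all the remaining vertices.
 * So compute the distance between the two start vertices, let both move toward
 * each other, and after they meet pick the best remaining branch in turn.
 * The difficulty is the input size: the choices after meeting must be
 * precomputed, otherwise the O(nm) approach times out.
 * Borrowed idea: precompute prefix sums by parity (even / odd positions of the
 * sorted branch weights), then binary-search.
 */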
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
using namespace std;
// Capacity of every vertex-indexed array (vertices are indexed 0..MAXN-1).
const int MAXN = 200010;
// Highest binary-lifting level used by moveUp()/getLCA(). 2^(SEC+1) > MAXN, so
// any depth difference fits in bits 0..SEC. dp's second dimension (20) must
// exceed SEC.
const int SEC = 17;
/// One directed arc of the adjacency-list graph. addEdge() stores every
/// undirected edge as two arcs. The pool below holds MAXN * 4 arcs.
struct _edge
{
    int to;     // vertex the arc points to
    int next;   // index of the next arc leaving the same vertex; -1 ends the list
    int val;    // edge weight
} edge[MAXN << 2];
/// A (vertex, value) pair. It is used both as an explicit-DFS stack entry and as
/// a (neighbour, branch weight) record. Ordering looks only at `to`, so pairs
/// with equal `to` compare equivalent regardless of `val`.
struct mPair
{
    int to;
    int val;

    mPair(int target = 0, int weight = 0) : to(target), val(weight) {}

    bool operator < (const mPair &other) const
    {
        return other.to > to;
    }
};
// Adjacency-list heads: head[v] is the first arc index of v, -1 when empty.
int head[MAXN];
// Depth of each vertex in the tree rooted at 1.
int dep[MAXN];
// Binary-lifting table: dp[v][k] is the 2^k-th ancestor of v (k = 0..SEC).
int dp[MAXN][20];
// Number of arcs stored in edge[] so far.
int nmEdge;
// val[v]: total edge weight inside v's subtree (filled by dfs()).
int val[MAXN];
// fath[v]: parent of v (0 for the root).
int fath[MAXN];
// Visited flags reused by dfs() and bfs().
int vis[MAXN];
// V1[x]: (neighbour, branch weight) pairs of x, sorted by neighbour id.
vector<mPair> V1[MAXN];
typedef vector<mPair>::iterator ITT;
// Queue storage for bfs().
int mQueue[MAXN];
// V2[x]: branch weights of x in descending order.
vector<int> V2[MAXN];
// fsnum[x] / scnum[x]: prefix sums of V2[x] over even / odd positions only.
vector<int> fsnum[MAXN];
vector<int> scnum[MAXN];
// Vertex count and query count of the current test case.
int n, m;
void addEdge(int u, int v, int val)
{
edge[nmEdge].to = v;
edge[nmEdge].next = head[u];
edge[nmEdge].val = val;
head[u] = nmEdge++;
edge[nmEdge].to = u;
edge[nmEdge].next = head[v];
edge[nmEdge].val = val;
head[v] = nmEdge++;
}
/// Return the vertex reached by climbing s levels up from a, using the
/// binary-lifting table (dp[v][k] = 2^k-th ancestor). Only bits 0..SEC of s are
/// considered. s == 0 returns a unchanged. The root is its own parent, so
/// over-long climbs stop at the root.
int moveUp(int a, int s)
{
    int node = a;
    for (int bit = 0; bit <= SEC; ++bit)
    {
        if (s & (1 << bit))
            node = dp[node][bit];
    }
    return node;
}
/// Lowest common ancestor of a and b, via binary lifting over dep[] / dp[][].
int getLCA(int a, int b)
{
    if (dep[b] > dep[a])
        std::swap(a, b);                  // make a the deeper vertex
    a = moveUp(a, dep[a] - dep[b]);       // lift a to b's depth
    if (a == b)
        return a;                         // b was an ancestor of a
    for (int k = SEC; k >= 0; --k)
    {
        // Jump both while their 2^k-th ancestors still differ (still below the LCA).
        if (dp[a][k] != dp[b][k])
        {
            a = dp[a][k];
            b = dp[b][k];
        }
    }
    return dp[a][0];
}
/// Sum of the underlying elements l..r (inclusive) of a sequence whose prefix
/// sums are stored in v (v[i] = element 0 + ... + element i).
/// An empty range (l > r) yields 0.
int getSum(int l, int r, vector<int> & v)
{
    if (r < l)
        return 0;
    return l == 0 ? v[r] : v[r] - v[l - 1];
}
/// First position in the sorted range [first, last) whose element does not
/// compare less than key. Ordering is mPair::operator<, i.e. by `to` only.
/// This is a thin wrapper over std::lower_bound. The call is qualified, so it
/// cannot recurse into this overload. Unqualified calls elsewhere still prefer
/// this non-template signature.
ITT lower_bound(ITT first, ITT last, mPair key)
{
    return std::lower_bound(first, last, key);
}
/// Binary search over v[l..r], which must be sorted in descending order.
/// Returns the LAST index m in [l, r] with v[m] >= a. When several entries
/// equal a, this is the right-most of them.
/// If no index qualifies (or the range is empty), returns l - 1. Previously the
/// result variable was left uninitialised in that case.
int m_search(int l, int r, int a, vector<int> & v)
{
    int res = l - 1;
    while (l <= r)
    {
        const int mid = l + ((r - l) >> 1);   // same value as (l+r)/2, no overflow
        if (v[mid] >= a)
        {
            res = mid;      // candidate; keep looking to the right
            l = mid + 1;
        }
        else
        {
            r = mid - 1;
        }
    }
    return res;
}
/// Answer one query for the start vertices a and b.
/// NOTE(review): the returned value appears to be the total collected by the
/// player starting at a; confirm against the problem statement.
///
/// Tables used (all built by bfs()):
///   V1[x] - (neighbour, weight of x's branch through that neighbour), sorted
///           by neighbour id. A branch weight is the total edge weight on that
///           side of x, including the connecting edge.
///   V2[x] - the same branch weights in descending order.
///   fsnum[x] / scnum[x] - prefix sums of V2[x] over even / odd positions,
///           i.e. what the first / second of two alternating pickers collect
///           when branches are taken largest-first.
int getResult(int a, int b)
{
int ans;
int dis, x, y, sz, lca, d, o;
vector<mPair>::iterator it;
// x = getLCA(a, b); /// common ancestor (LCA) of a and b
// d = dep[a]+dep[b]-(dep[x]<<1); /// distance between a and b
if (a == b) /// a and b are the same vertex: first picker's share of all branches
{
sz = V2[a].size();
return getSum(0, sz-1, fsnum[a]);
}
ans = 0;
lca = getLCA(a, b);
d = dep[a]+dep[b]-(dep[lca]<<1);  /// path length between a and b
if ( d == 1) /// a and b are adjacent: only one vertex (a) has to be excluded
{
sz = V2[b].size();
/// a walks toward b. The entry of neighbour a in V1[b] is b's branch that
/// contains a; y is its position among b's sorted branch weights.
it = lower_bound(V1[b].begin(), V1[b].end(), mPair(a,-1));
y = m_search(0, sz-1, (*it).val, V2[b]);
/// Whole a-side branch plus the second picker's share of b's other branches.
/// Removing position y shifts every later position by one, so odd positions
/// before y (scnum) and even positions after y (fsnum) are the second picker's.
ans += getSum(0, y-1, scnum[b]) + getSum(y+1, sz-1, fsnum[b]) + (*it).val ;
return ans;
}
/// Two vertices have to be excluded (d >= 2).
/// Let o be the vertex on the a-b path at distance ceil(d/2) from a (and
/// floor(d/2) from b). The branches below leave o in o, and turn a / b into
/// o's neighbours on the path toward the original a / b.
dis = (1+d)>>1;
if (dep[a] == dep[b] || dep[a]==dep[b]+1)
{
/// Equal depth, or a one level deeper: o is the LCA itself.
a = moveUp(a, dis-1);
b = moveUp(b, d-dis-1);
o = dp[a][0];
}
else if (dep[a] > dep[b])
{
/// a at least two levels deeper: o lies strictly below the LCA on a's side,
/// so the next vertex toward b is o's parent.
o = moveUp(a, dis);
a = moveUp(a, dis-1);
b = dp[o][0];
}
else
{
/// b deeper (symmetric case): climb floor(d/2) levels from b.
dis = d - dis;
o = moveUp(b, dis);
b = moveUp(b, dis-1);
a = dp[o][0];
}
/// x / y: positions in V2[o] of o's branches toward a / toward b.
sz = V1[o].size();
it = lower_bound(V1[o].begin(), V1[o].end(), mPair(a,0));
x = m_search(0, sz-1, (*it).val,V2[o]);
it = lower_bound(V1[o].begin(), V1[o].end(), mPair(b,0));
y = m_search(0, sz-1, (*it).val,V2[o]);
if (sz < 3)
{
/// o has only the two path branches (a and b are both neighbours of o).
return V2[o][x];
}
if (d & 0x1) /// odd distance: compute the b-side total, complemented below
ans += V2[o][y];
else
ans += V2[o][x];
/// Equal branch weights: m_search returned the same (right-most) index for
/// both, and index y-1 holds the same weight, so use it for x.
if ( x == y) x = y-1;
if ( x > y ) swap(x,y);
/// First picker's share of o's remaining branches after removing positions
/// x < y: parity is unchanged before x, shifted by one between x and y, and
/// shifted by two after y.
ans += getSum(0, x-1, fsnum[o]);
ans += getSum(x+1, y-1, scnum[o]);
ans += getSum(y+1, sz-1, fsnum[o]);
if (d & 0x1)
ans = val[1]-ans;  /// complement against the total edge weight of the tree
return ans;
}
// Scratch state for the iterative dfs():
//   now[v] - cursor into v's adjacency list (next arc index still to examine);
//   stc[]  - explicit stack of (vertex, weight of the edge from its parent).
int now[MAXN];
mPair stc[MAXN];

/// Root the tree at vertex 1 with an explicit-stack DFS and fill, for every
/// reached vertex v:
///   dep[v]            - depth (root = 0);
///   fath[v], dp[v][0] - parent (fath[1] = 0, dp[1][0] = 1);
///   val[v]            - total edge weight inside v's subtree.
/// Resets now[], val[], dep[] for indices 0..n and clears vis[] entirely.
void dfs()
{
    memset(vis, 0, sizeof vis);
    for (int v = 0; v <= n; ++v)
    {
        now[v] = head[v];
        val[v] = 0;
        dep[v] = 0;
    }
    dep[1] = 0;
    dp[1][0] = 1;
    fath[1] = 0;
    vis[1] = 1;

    int top = 0;
    stc[top++] = mPair(1, 0);
    while (top > 0)
    {
        const int u = stc[top - 1].to;
        int &cursor = now[u];
        // Skip arcs to visited vertices (the parent, or children already explored).
        while (cursor != -1 && vis[edge[cursor].to])
            cursor = edge[cursor].next;
        if (cursor < 0)
        {
            // u is finished: fold its subtree weight plus its parent edge upward.
            val[fath[u]] += val[u] + stc[top - 1].val;
            --top;
            continue;
        }
        // Descend into the first unvisited neighbour. The cursor is left on this
        // arc, which gets skipped once the child is marked visited.
        const int v = edge[cursor].to;
        stc[top++] = mPair(v, edge[cursor].val);
        fath[v] = dp[v][0] = u;
        dep[v] = dep[u] + 1;
        vis[v] = 1;
    }
}
/// Breadth-first pass from vertex 1 (after dfs() filled val[]) that records the
/// weight of every branch hanging off every vertex. For a tree edge with parent
/// u, child v and weight w:
///   u's branch through v weighs val[v] + w;
///   v's branch through u weighs val[1] - val[v].
/// Then, for each vertex x in 1..n:
///   V2[x] is sorted descending and V1[x] is sorted by neighbour id;
///   fsnum[x][k] / scnum[x][k] = sum of V2[x][0..k] over even / odd positions.
/// Expects V1/V2/fsnum/scnum to be empty on entry (solve() clears them).
void bfs()
{
    memset(vis, 0, sizeof vis);
    int front = 0, back = 0;
    vis[1] = 1;
    mQueue[back++] = 1;
    while (front < back)
    {
        const int u = mQueue[front++];
        for (int e = head[u]; e != -1; e = edge[e].next)
        {
            const int v = edge[e].to;
            if (vis[v])
                continue;
            vis[v] = 1;
            const int down = val[v] + edge[e].val;   // u's branch through child v
            const int up = val[1] - val[v];          // v's branch through parent u
            V1[u].push_back(mPair(v, down));
            V1[v].push_back(mPair(u, up));
            V2[u].push_back(down);
            V2[v].push_back(up);
            mQueue[back++] = v;
        }
    }

    for (int x = 1; x <= n; ++x)
    {
        sort(V2[x].begin(), V2[x].end(), greater<int>());
        sort(V1[x].begin(), V1[x].end());
        int evenRun = 0, oddRun = 0;
        for (size_t k = 0; k < V2[x].size(); ++k)
        {
            if (k % 2 == 0)
                evenRun += V2[x][k];
            else
                oddRun += V2[x][k];
            fsnum[x].push_back(evenRun);
            scnum[x].push_back(oddRun);
        }
    }
}
/// Build all per-test-case tables:
///   1. root the tree (dfs);
///   2. reset and rebuild the per-vertex branch tables (bfs);
///   3. fill the binary-lifting table dp[v][k] = dp[dp[v][k-1]][k-1].
void solve()
{
    dfs();

    for (int v = 0; v <= n; ++v)
    {
        V1[v].clear();
        V2[v].clear();
        fsnum[v].clear();
        scnum[v].clear();
    }
    bfs();

    // Level k is built from level k-1, so the level loop must stay outermost.
    for (int k = 1; k <= SEC; ++k)
        for (int v = 1; v <= n; ++v)
            dp[v][k] = dp[dp[v][k - 1]][k - 1];
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt", "r", stdin);
#endif
int t;
int i, a, j, k;
scanf("%d", &t);
while (t--)
{
scanf("%d%d", &n, &m);
memset(head, -1, 4*(5+n));
nmEdge = 0;
for (a = 1; a< n; ++a)
{
scanf("%d%d%d", &i, &j, &k);
addEdge(i, j, k);
}
solve();
while (m--)
{
scanf("%d%d", &i, &j);
printf("%d\n", getResult(i, j));
}
}
return 0;
}