【题目链接】
【思路要点】
- 考虑枚举最终路径的一端 $x$,那么路径的另一端 $y$ 应当在所有过 $x$ 的路径的并上,也即这些路径两端所有点形成的虚树上。因此,我们需要知道这个虚树的大小(边数):它恰好等于可选的 $y \ne x$ 的个数。对所有 $x$ 求和后每个无序点对被计算了两次,故最终答案需除以 $2$。
- 注意到虚树大小(边数)即为将点按照 $dfs$ 序环形排列后,相邻的点在原树上的距离和的一半。可以用线段树维护 $dfs$ 序,记录区间最左、最右侧的点,以及区间内相邻点的距离和,即可支持加点和删点。
- 用线段树合并即可计算答案。为保证复杂度,需要用 $ST$ 表预处理后 $O(1)$ 询问 $\mathrm{LCA}$。
- 时间复杂度 $O(N\log N+M\log N)$。
【代码】
// Tree-path pair counting.
//
// For every vertex x, the union of all input paths passing through x is a
// connected subtree; the number of vertices y != x inside it equals its edge
// count. Summing that over every x counts each unordered pair {x, y} twice,
// so the program prints half of the total.
//
// The edge count of the subtree spanned by a vertex set equals half of the
// sum of tree distances between cyclically consecutive members of the set in
// DFS preorder. A mergeable dynamic segment tree over preorder positions
// maintains that sum; merging children into parents bottom-up leaves, at
// each vertex, exactly the endpoints of the paths through that vertex.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

const int MAXN = 2e5 + 5;  // capacity of every per-vertex / per-position array
const int MAXP = 8e6 + 5;  // segment-tree node pool (node 0 is a null sentinel)
const int MAXLOG = 18;     // sparse-table levels; ranges shorter than 2^18 > MAXN
typedef long long ll;

// Reads one (optionally negative) integer from stdin.
template <typename T> void read(T &x) {
  x = 0;
  int f = 1;
  // Keep the character as int: isdigit() is undefined for negative char
  // values, and EOF must stay distinguishable so truncated input ends the
  // scan instead of looping forever.
  int c = getchar();
  for (; c != EOF && !isdigit(c); c = getchar())
    if (c == '-') f = -f;
  for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
  x *= f;
}

// Writes an integer to stdout (no separator).
template <typename T> void write(T x) {
  if (x < 0) x = -x, putchar('-');
  if (x > 9) write(x / 10);
  putchar(x % 10 + '0');
}

// Writes an integer followed by a newline.
template <typename T> void writeln(T x) {
  write(x);
  puts("");
}

vector<int> a[MAXN];  // adjacency lists of the tree
int n, m, timer;
int depth[MAXN];   // depth of each vertex; the root has depth 1, depth[0] == 0
int father[MAXN];  // parent of each vertex; 0 for the root
int dfn[MAXN];     // 1-based preorder index of each vertex
int Real[MAXN];    // inverse of dfn: Real[dfn[v]] == v

namespace rmq {
// Sparse table over preorder positions 1..len; each slot stores a vertex and
// comb() keeps the shallower of two vertices.
int Min[MAXN][MAXLOG], Log[MAXN];

int comb(int x, int y) { return depth[x] < depth[y] ? x : y; }

// Shallowest vertex among positions l..r (requires 1 <= l <= r <= len).
int queryMin(int l, int r) {
  int k = Log[r - l + 1];
  return comb(Min[l][k], Min[r - (1 << k) + 1][k]);
}

// Builds the table from v[1..len].
void init(const int *v, int len) {
  Log[1] = 0;
  for (int i = 1; i <= len; i++) {
    Min[i][0] = v[i];
    if (i > 1) Log[i] = Log[i / 2] + 1;
  }
  for (int t = 1; t < MAXLOG; t++)
    for (int i = 1; i + (1 << t) - 1 <= len; i++)
      Min[i][t] = comb(Min[i][t - 1], Min[i + (1 << (t - 1))][t - 1]);
}
}  // namespace rmq

// Preorder DFS from `start`, filling dfn, Real, father and depth.
// Iterative so a path-shaped tree cannot overflow the call stack; a stack
// of pending vertices still yields a valid preorder (subtrees contiguous).
void work(int start) {
  vector<int> stk(1, start);
  father[start] = 0;
  depth[start] = 1;
  while (!stk.empty()) {
    int pos = stk.back();
    stk.pop_back();
    dfn[pos] = ++timer;
    Real[timer] = pos;
    for (int x : a[pos])
      if (x != father[pos]) {
        father[x] = pos;
        depth[x] = depth[pos] + 1;
        stk.push_back(x);
      }
  }
}

// O(1) LCA on the preorder: for dfn[x] < dfn[y], the shallowest vertex with
// preorder index in (dfn[x], dfn[y]] is a child of lca(x, y). This needs an
// n-row table, unlike an Euler tour, which has 2n-1 entries.
inline int lca(int x, int y) {
  if (x == y) return x;
  int l = dfn[x], r = dfn[y];
  if (l > r) swap(l, r);
  return father[rmq::queryMin(l + 1, r)];
}

// Tree distance between the vertices whose preorder indices are x and y.
inline int dist(int x, int y) {
  x = Real[x], y = Real[y];
  return depth[x] + depth[y] - 2 * depth[lca(x, y)];
}

// Dynamic, mergeable segment trees over preorder positions [1, n].
// A leaf stores a signed multiplicity and counts as "present" while that
// multiplicity is non-zero. Every node caches its leftmost / rightmost present
// position (0 if none) and the sum of tree distances between consecutive
// present positions inside its range.
struct SegmentTree {
  struct Node {
    int lc, rc;
    int sum, cnt, lpos, rpos;
  } a[MAXP];
  int size, n;

  void init(int x) {
    n = x;
    size = 0;
  }

  // Recomputes root's cached summary from its children (an absent child is
  // node 0, whose fields are all zero).
  void update(int root) {
    Node &cur = a[root];
    const Node &L = a[cur.lc], &R = a[cur.rc];
    cur.sum = L.sum + R.sum;
    if (L.rpos && R.lpos) cur.sum += dist(L.rpos, R.lpos);
    cur.lpos = L.lpos ? L.lpos : R.lpos;
    cur.rpos = R.rpos ? R.rpos : L.rpos;
  }

  // Refreshes a leaf at position l after its multiplicity changed.
  static void settle(Node &leaf, int l) { leaf.lpos = leaf.rpos = leaf.cnt ? l : 0; }

  // Adds val to the multiplicity at position pos, creating nodes as needed.
  void modify(int &root, int l, int r, int pos, int val) {
    if (root == 0) root = ++size;
    if (l == r) {
      a[root].cnt += val;
      settle(a[root], l);
      return;
    }
    int mid = (l + r) / 2;
    if (pos <= mid)
      modify(a[root].lc, l, mid, pos, val);
    else
      modify(a[root].rc, mid + 1, r, pos, val);
    update(root);
  }
  void modify(int &root, int pos, int val) { modify(root, 1, n, pos, val); }

  // Merges tree y into tree x (both covering [l, r]); y's nodes are reused
  // and must not be used through their old root afterwards.
  int merge(int x, int y, int l, int r) {
    if (x == 0 || y == 0) return x + y;
    if (l == r) {
      a[x].cnt += a[y].cnt;
      settle(a[x], l);
      return x;
    }
    int mid = (l + r) / 2;
    a[x].lc = merge(a[x].lc, a[y].lc, l, mid);
    a[x].rc = merge(a[x].rc, a[y].rc, mid + 1, r);
    update(x);
    return x;
  }
  void join(int &to, int from) { to = merge(to, from, 1, n); }
} ST;

int root[MAXN];  // segment-tree root per vertex
ll ans;

// Visits vertices in reverse preorder, so every child is finished before its
// parent. At vertex v, root[v] then holds the endpoints of all paths through v.
// The edge count of the subtree they span (half of the cyclic distance sum) is
// added to ans, then v's tree is merged into its parent's.
void getans() {
  for (int i = n; i >= 1; i--) {
    int v = Real[i];
    const SegmentTree::Node &t = ST.a[root[v]];
    if (t.lpos != t.rpos) {
      int cyc = t.sum + dist(t.lpos, t.rpos);  // close the cycle
      ans += cyc / 2;
    }
    if (father[v]) ST.join(root[father[v]], root[v]);
  }
}

int main() {
  read(n), read(m);
  for (int i = 1; i <= n - 1; i++) {
    int x, y;
    read(x), read(y);
    a[x].push_back(y);
    a[y].push_back(x);
  }
  work(1);
  ST.init(n);
  rmq::init(Real, n);
  for (int i = 1; i <= m; i++) {
    int x, y;
    read(x), read(y);
    int z = father[lca(x, y)];
    // Both endpoints are inserted at x and at y and cancelled (-2 each) at the
    // LCA's parent. After subtree merging, a vertex therefore holds this path's
    // endpoints exactly when it lies on the path.
    ST.modify(root[x], dfn[x], 1);
    ST.modify(root[x], dfn[y], 1);
    ST.modify(root[y], dfn[x], 1);
    ST.modify(root[y], dfn[y], 1);
    // When the LCA is the tree root there is no vertex above it to cancel at;
    // updating root[0] would only waste pool nodes.
    if (z) {
      ST.modify(root[z], dfn[x], -2);
      ST.modify(root[z], dfn[y], -2);
    }
  }
  getans();
  writeln(ans / 2);
  return 0;
}