1. 解题思路
这一题核心思路还是一个LCA算法,即最小公共父节点算法。
具体来说的话,对于每一个query给到的节点 u , v u,v u,v,我们都可以先通过LCA算法找到其最小公共父节点 p p p,然后我们可以提前在预处理过程中记录下每一个节点到根节点 0 0 0的距离为 d i d_i di,此时不难发现三个节点 u , v , p u,v,p u,v,p到根节点的距离分别就是 d u , d v , d p d_u, d_v, d_p du,dv,dp,而 u , v u,v u,v两节点之间的距离就是 d u + d v − 2 d p d_u+d_v-2d_p du+dv−2dp,此时,我们不难找到 u v uv uv线段的中点。
此时,我们需要分类讨论一下:
- 如果其中点在线段右侧,即 p v pv pv这一段,那么我们要做的就是找到 v v v的父节点当中第一个距离大于等于 d v − ( d u + d v − 2 d p ) / 2 d_v-(d_u+d_v-2d_p)/2 dv−(du+dv−2dp)/2的节点;
- 如果其中点在线段左侧,即 p u pu pu这一段,那么我们要做的就是找到 v v v的父节点当中第一个距离小于 d v − ( d u + d v − 2 d p ) / 2 d_v-(d_u+d_v-2d_p)/2 dv−(du+dv−2dp)/2的节点;
而这个我们只需要稍微调整一下LCA的倍增过程即可快速实现。
2. 代码实现
给出python代码实现如下:
import math
from collections import deque
from typing import List, Tuple
class Tree:
def __init__(self, n: int, edges: List[Tuple[int, int, int]], root: int = 0):
self.n = n
self.max_log = math.floor(math.log2(n)) + 1 # 最大跳跃步数的对数
self.graph = [[] for _ in range(n)]
self.distances = [0 for _ in range(n)]
self.depth = [-1] * n
self.parent = [[-1] * n for _ in range(self.max_log)] # parent[k][i]: i 的第 2^k 级祖先
# 构建邻接表
for u, v, w in edges:
self.graph[u].append((v, w))
self.graph[v].append((u, w))
# 预处理深度和祖先表
self._bfs(root)
def _bfs(self, root: int):
"""BFS 初始化深度和直接父节点(即 2^0 级祖先)"""
queue = deque([root])
self.depth[root] = 0
self.distances[root] = 0
self.parent[0][root] = -1 # 根节点无父节点
while queue:
u = queue.popleft()
for v, w in self.graph[u]:
if v == self.parent[0][u]:
continue
self.depth[v] = self.depth[u] + 1
self.distances[v] = self.distances[u] + w
self.parent[0][v] = u
queue.append(v)
# 递推计算 2^k 级祖先
for k in range(1, self.max_log):
for i in range(self.n):
if self.parent[k-1][i] != -1:
self.parent[k][i] = self.parent[k-1][self.parent[k-1][i]]
def query(self, u: int, v: int) -> int:
"""查询节点 u 和 v 的最近公共祖先"""
# 确保 u 是深度较大的节点
if self.depth[u] < self.depth[v]:
u, v = v, u
# 将 u 上跳到与 v 同深度
diff = self.depth[u] - self.depth[v]
k = 0
while diff:
if diff & 1:
u = self.parent[k][u]
diff >>= 1
k += 1
if u == v:
return u
# 同步上跳,寻找最近公共祖先
for k in range(self.max_log - 1, -1, -1):
if self.parent[k][u] != self.parent[k][v]:
u = self.parent[k][u]
v = self.parent[k][v]
return self.parent[0][u]
def query_distance(self, u):
return self.distances[u]
def query_closest_parent(self, u: int, d: float):
h = self.max_log-1
while h >= 0:
if self.parent[h][u] != -1 and self.distances[self.parent[h][u]] >= d:
u = self.parent[h][u]
h -= 1
return u
class Solution:
def findMedian(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
tree = Tree(n, edges, 0)
def query(u, v):
p = tree.query(u, v)
du, dv, dp = tree.query_distance(u), tree.query_distance(v), tree.query_distance(p)
d1, d2 = du-dp, dv-dp
d = (d1+d2) / 2
if d <= d2:
return tree.query_closest_parent(v, dv-d)
else:
w = tree.query_closest_parent(u, du-d)
dw = tree.query_distance(w)
return tree.parent[0][w] if du-dw != d else w
return [query(u, v) for u, v in queries]
提交代码评测得到:耗时1292ms,占用内存102.01MB。
413

被折叠的 条评论
为什么被折叠?



