看到最大异或和,想到线性基。
线性基参考链接1:http://blog.youkuaiyun.com/qaq__qaq/article/details/53812883
线性基参考链接2:https://www.cnblogs.com/ljh2000-jump/p/5869991.html
考虑使用倍增算法:
f[u][i]
f
[
u
]
[
i
]
表示以节点
u
u
为出发点往上走步,算上
u
u
总共个点组成的线性基,
也就是把
f[u][i−1]
f
[
u
]
[
i
−
1
]
和
f[fa[u][i−1]][i−1]
f
[
f
a
[
u
]
[
i
−
1
]
]
[
i
−
1
]
两个线性基合并。
fa[u][i]
f
a
[
u
]
[
i
]
表示点
u
u
向上跳步到达的节点。
询问路径的最大异或和时,可以在求LCA的过程中,把路径拆成不超过
O(logn)
O
(
log
n
)
个
f[...][...]
f
[
.
.
.
]
[
.
.
.
]
,把这
O(logn)
O
(
log
n
)
个线性基合并后查询最大异或和就是答案。
设线性基的大小(位数)为
K
K
,那么一次合并的代价是,计算一个
f[...][...]
f
[
.
.
.
]
[
.
.
.
]
的代价也是
O(K2)
O
(
K
2
)
,因此预处理复杂度为
O(nK2logn)
O
(
n
K
2
log
n
)
。
一次询问中需要
O(logn)
O
(
log
n
)
次合并操作,因此总复杂度
O((n+q)K2logn)
O
(
(
n
+
q
)
K
2
log
n
)
。
代码:
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
inline int read() {
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
typedef long long ll;
const int N = 2e4 + 5, M = 4e4 + 5, L = 66, LogN = 17;
int n, q, ecnt, nxt[M], adj[N], go[M], dep[N], fa[N][LogN];
ll a[N];
struct cyx {
ll orz[L];
inline void init() {
int i; for (i = 61; i; i--) orz[i] = -1;
}
inline void ins(ll x) {
int i; for (i = 61; i; i--) {
if (!((x >> i - 1) & 1)) continue;
if (orz[i] == -1) return (void) (orz[i] = x);
else x ^= orz[i];
}
}
inline ll max_xor() {
int i; ll ans = 0; for (i = 61; i; i--)
if ((ans ^ orz[i]) > ans) ans ^= orz[i];
return ans;
}
} dalao[N][LogN];
inline cyx mer(cyx a, cyx b) {
int i; for (i = 61; i; i--)
a.ins(b.orz[i]);
return a;
}
void add_edge(int u, int v) {
nxt[++ecnt] = adj[u]; go[adj[u] = ecnt] = v;
nxt[++ecnt] = adj[v]; go[adj[v] = ecnt] = u;
}
void dfs(int u, int fu) {
int i; dep[u] = dep[fa[u][0] = fu] + 1;
dalao[u][0].init(); dalao[u][0].ins(a[u]);
for (i = 0; i <= 14; i++)
fa[u][i + 1] = fa[fa[u][i]][i],
dalao[u][i + 1] = mer(dalao[u][i], dalao[fa[u][i]][i]);
for (int e = adj[u], v; e; e = nxt[e])
if ((v = go[e]) != fu) dfs(v, u);
}
ll lca(int u, int v) {
int i; cyx ans; if (dep[u] < dep[v]) swap(u, v);
ans.init(); for (i = 15; i >= 0; i--) {
if (dep[fa[u][i]] >= dep[v]) ans = mer(ans, dalao[u][i]), u = fa[u][i];
if (u == v) return ans = mer(ans, dalao[u][0]), ans.max_xor();
}
for (i = 15; i >= 0; i--)
if (fa[u][i] != fa[v][i])
ans = mer(ans, dalao[u][i]), ans = mer(ans, dalao[v][i]),
u = fa[u][i], v = fa[v][i];
ans = mer(ans, dalao[u][1]); ans = mer(ans, dalao[v][1]);
return ans.max_xor();
}
int main() {
int i, x, y; n = read(); q = read();
for (i = 1; i <= n; i++) scanf("%lld", &a[i]);
for (i = 1; i < n; i++) x = read(), y = read(), add_edge(x, y);
dfs(1, 0); while (q--)
x = read(), y = read(), printf("%lld\n", lca(x, y));
return 0;
}