这题需要用到线性基的知识,下面先简单介绍一下线性基(参考博客:http://www.cnblogs.com/ljh2000-jump/p/5869991.html)。
- 线性基就是一组数 $a_1, a_2, \dots, a_n$,其中 $a_x$ 的最高位的 1 在第 $x$ 位。
- 线性基中没有0。
- 假设集合 $A$ 的一组线性基为 $B$,则 $B$ 中元素通过 xor 生成的集合 $C$ 与 $A$ 中元素通过 xor 生成的集合相同。
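线性基的构造方式:给定集合 $A$,依次把 $A$ 中的每个数 $p$ 插入线性基:从高位到低位扫描 $p$ 的二进制位,若第 $i$ 位为 1,则当 $a_i$ 为空时令 $a_i = p$ 并结束插入,否则令 $p \leftarrow p \oplus a_i$ 后继续扫描;若 $p$ 变为 0,说明它已经能由线性基中的数异或得到,无需插入。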
以下是这道题的树链剖分 + 线性基做法:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity of every per-vertex array; vertices are indexed 1..n, so n must be < N.
#define N 20003
// Number of bit positions the XOR basis tracks: bits 0..B-1 (0..60).
// NOTE(review): assumes every weight fits in these 61 bits - confirm against the input limits.
#define B 61
// Segment-tree helpers. They expand in terms of the *caller's* locals l, r and d,
// so they only work inside build()/query(). `l+r>>1` parses as (l+r)>>1.
#define mid (l+r>>1)
#define lc (d<<1)
#define rc (d<<1|1)
typedef long long ll;
// Adjacency lists of the tree.
vector<int> E[N];
// Heavy-light decomposition data, filled by dfs() and dfs1():
//   dep - depth, fa - parent, top - head of the vertex's heavy chain,
//   dfn - DFS position (-1 = not yet visited), tid - inverse of dfn,
//   son - heavy child (son[u] == u when u has no child), cnt - subtree size,
//   tot - last DFS position handed out.
int dep[N], fa[N], top[N], dfn[N], tid[N], son[N], cnt[N], tot;
// Vertex weights.
ll val[N];
// First heavy-light pass: record the parent of u, the size of u's subtree and
// u's heavy child (the child with the largest subtree, first one on ties).
// A vertex without children keeps son[u] == u as a "no heavy child" sentinel.
void dfs(int u, int f) {
    fa[u] = f;
    cnt[u] = 1;
    son[u] = u;
    for (int child : E[u]) {
        if (child == f) continue;  // skip the edge back to the parent
        dfs(child, u);
        cnt[u] += cnt[child];
        const bool noHeavyYet = (son[u] == u);
        if (noHeavyYet || cnt[son[u]] < cnt[child]) son[u] = child;
    }
}
// Second heavy-light pass: hand out DFS positions so that each heavy chain
// occupies a contiguous range. The heavy child is visited first and inherits
// the chain head `Top`; every light child starts a new chain headed by itself.
// Relies on dfn[] being pre-filled with -1 for unvisited vertices.
void dfs1(int u, int Top, int d) {
    dep[u] = d;
    top[u] = Top;
    dfn[u] = ++tot;
    tid[tot] = u;
    const int heavy = son[u];
    // For a vertex without children son[u] == u, already numbered, so this is skipped.
    if (dfn[heavy] == -1) dfs1(heavy, Top, d + 1);
    for (int next : E[u]) {
        // The parent and the heavy child are already numbered at this point.
        if (dfn[next] == -1) dfs1(next, next, d + 1);
    }
}
struct Basic {
ll a[B];
Basic(){}
Basic(ll p) {
init();
add(p);
}
void init() {
memset(a, 0, sizeof(a));
}
void add(ll p) {
for (int i = B-1;i >= 0 && p;i--) {
if (p&(1LL<<i)) {
if (!a[i]) {
a[i] = p;
break;
}
else p ^= a[i];
}
}
}
Basic operator + (Basic o) const {
Basic re;
memcpy(re.a, a, sizeof(a));
for (int i = 0;i < B;i++) {
re.add(o.a[i]);
}
return re;
}
ll getMax() {
ll re = 0;
for (int i = B-1;i >= 0;i--) {
if (re < (re^a[i])) re ^= a[i];
}
return re;
}
}tr[N<<2];
// Build segment-tree node d over DFS positions [l, r]; each node stores the XOR
// basis of the weights of the vertices placed at those positions.
void build(int d, int l, int r) {
    if (l != r) {
        // Internal node: build both halves, then merge their bases.
        build(lc, l, mid);
        build(rc, mid + 1, r);
        tr[d] = tr[lc] + tr[rc];
    } else {
        // Leaf: the single vertex whose DFS position is l.
        tr[d] = Basic(val[tid[l]]);
    }
}
// Return the XOR basis of DFS positions [L, R] inside node d, which covers [l, r].
// Expects l <= L <= R <= r. The query range is narrowed as it is split, so an
// exact match with a node's range ends the recursion.
Basic query(int d, int l, int r, int L, int R) {
    if (L == l && R == r) return tr[d];
    // Entirely inside one child: descend into that child only.
    if (R <= mid) return query(lc, l, mid, L, R);
    if (L > mid) return query(rc, mid + 1, r, L, R);
    // Straddles the midpoint: answer each half separately and merge.
    Basic lhs = query(lc, l, mid, L, mid);
    Basic rhs = query(rc, mid + 1, r, mid + 1, R);
    return lhs + rhs;
}
// Maximum XOR obtainable from any subset of the vertex weights on the u..v path.
// Climbs heavy chains and merges each chain segment's basis fetched from the
// segment tree.
ll query(int u, int v) {
    Basic path;
    path.init();
    while (top[u] != top[v]) {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[top[v]] > dep[top[u]]) swap(u, v);
        const int head = top[u];
        // A chain is contiguous in DFS order: [dfn[head], dfn[u]].
        path = path + query(1, 1, tot, dfn[head], dfn[u]);
        u = fa[head];
    }
    // Both endpoints now lie on one chain, so the rest is a single range.
    const int lo = min(dfn[u], dfn[v]);
    const int hi = max(dfn[u], dfn[v]);
    path = path + query(1, 1, tot, lo, hi);
    return path.getMax();
}
// Read the next non-negative decimal integer from stdin.
// Any non-digit characters before it are skipped. Returns 0 if EOF is reached
// before any digit.
// getchar()'s result is kept as int: stored in a char, EOF can never be
// recognised as EOF, so the skip loop would spin forever on truncated input.
int in() {
    int c = getchar();
    int re = 0;
    while (c != EOF && (c > '9' || c < '0')) c = getchar();
    while (c >= '0' && c <= '9') {
        re = re * 10 + (c - '0');
        c = getchar();
    }
    return re;
}
int main() {
int n, q, u, v, i, j;
while (~scanf("%d%d", &n, &q)) {
for (i = 1;i <= n;i++) {
scanf("%lld", val+i);
E[i].clear();
}
tot = 0;
for (i = 1;i < n;i++) {
u = in(), v = in();//scanf("%d%d", &u, &v);
E[u].push_back(v);
E[v].push_back(u);
}
memset(dfn, -1, sizeof(dfn));
dfs(1, 0);
dfs1(1, 1, 0);
build(1, 1, n);
while (q--) {
u = in(), v = in();//scanf("%d%d", &u, &v);
printf("%lld\n", query(u, v));
}
}
}
接下来是复杂度少一个 $\log$ 的倍增 LCA + 线性基做法:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity of every per-vertex array; vertices are indexed 1..n, so n must be < N.
#define N 20003
// Number of bit positions the XOR basis tracks: bits 0..B-1 (0..60).
// NOTE(review): assumes every weight fits in these 61 bits - confirm against the input limits.
#define B 61
// NOTE(review): mid/lc/rc are segment-tree helpers not referenced by the code below.
#define mid (l+r>>1)
#define lc (d<<1)
#define rc (d<<1|1)
typedef long long ll;
// Adjacency lists of the tree.
vector<int> E[N];
// dep - depth; fa[u][i] - the 2^i-th ancestor of u (saturating at the root,
// which is its own parent). 16 levels allow jumps of up to 2^16 - 1 steps.
int dep[N], fa[N][16];
// Vertex weights.
ll val[N];
struct Basic {
ll a[B];
Basic(){}
Basic(ll p) {
init();
add(p);
}
void init() {
memset(a, 0, sizeof(a));
}
void add(ll p) {
for (int i = B-1;i >= 0 && p;i--) {
if (p&(1LL<<i)) {
if (!a[i]) {
a[i] = p;
break;
}
else p ^= a[i];
}
}
}
void operator += (Basic o) {
for (int i = 0;i < B;i++) {
add(o.a[i]);
}
}
ll getMax() {
ll re = 0;
for (int i = B-1;i >= 0;i--) {
if (re < (re^a[i])) re ^= a[i];
}
return re;
}
}dp[N][16];
void dfs(int u, int f, int d) {
dep[u] = d;
fa[u][0] = f;
dp[u][0] = Basic(val[f]);
int i, v;
for (i = 1;i < 16;i++) {
fa[u][i] = fa[fa[u][i-1]][i-1];
dp[u][i] = dp[u][i-1];
dp[u][i] += dp[fa[u][i-1]][i-1];
}
for (i = 0;i < E[u].size();i++) {
v = E[u][i];
if (v == f) continue;
dfs(v, u, d+1);
}
}
// Maximum XOR obtainable from any subset of the vertex weights on the u..v path,
// using the binary-lifting tables built by dfs().
ll query(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    Basic acc(val[u]);
    // Lift the deeper endpoint to v's depth, absorbing every vertex passed.
    for (int k = 15; k >= 0; --k) {
        const int up = fa[u][k];
        if (dep[up] >= dep[v]) {
            acc += dp[u][k];
            u = up;
        }
    }
    // v was an ancestor of u: the path is complete.
    if (u == v) return acc.getMax();
    acc += Basic(val[v]);
    // Lift both endpoints together while they stay below the LCA.
    for (int k = 15; k >= 0; --k) {
        if (fa[u][k] == fa[v][k]) continue;
        acc += dp[u][k];
        acc += dp[v][k];
        u = fa[u][k];
        v = fa[v][k];
    }
    // u and v are now children of the LCA; dp[u][0] holds the LCA's weight.
    acc += dp[u][0];
    return acc.getMax();
}
// Read the next non-negative decimal integer from stdin.
// Any non-digit characters before it are skipped. Returns 0 if EOF is reached
// before any digit.
// getchar()'s result is kept as int: stored in a char, EOF can never be
// recognised as EOF, so the skip loop would spin forever on truncated input.
int in() {
    int c = getchar();
    int re = 0;
    while (c != EOF && (c > '9' || c < '0')) c = getchar();
    while (c >= '0' && c <= '9') {
        re = re * 10 + (c - '0');
        c = getchar();
    }
    return re;
}
int main() {
int n, q, u, v, i, j;
while (~scanf("%d%d", &n, &q)) {
for (i = 1;i <= n;i++) {
scanf("%lld", val+i);
E[i].clear();
}
for (i = 1;i < n;i++) {
u = in(), v = in();//scanf("%d%d", &u, &v);
E[u].push_back(v);
E[v].push_back(u);
}
dfs(1, 1, 0);
while (q--) {
u = in(), v = in();//scanf("%d%d", &u, &v);
printf("%lld\n", query(u, v));
}
}
}