题意
给一颗
n
n
n个节点的树,节点
u
u
u有权值
a
[
u
]
(
1
≤
a
[
u
]
≤
1
e
6
)
a[u](1 \le a[u] \le 1e6)
a[u](1≤a[u]≤1e6),计算下列表达式:
∑
i
=
1
n
∑
j
=
i
+
1
n
[
a
[
i
]
x
o
r
a
[
j
]
=
=
a
[
l
c
a
(
i
,
j
)
]
]
(
i
x
o
r
j
)
\sum_{i=1}^{n}\sum_{j=i+1}^{n}[a[i]\ xor\ a[j]==a[lca(i,j)]](i\ xor\ j)
i=1∑nj=i+1∑n[a[i] xor a[j]==a[lca(i,j)]](i xor j)
思路
由于
a
[
u
]
≠
0
a[u] \ne 0
a[u]=0,于是不存在
i
,
j
,
l
c
a
(
i
,
j
)
i,j,lca(i,j)
i,j,lca(i,j)在一条链上的情况,于是可以考虑枚举以
l
c
a
(
i
,
j
)
lca(i,j)
lca(i,j)为根的子树然后计算贡献。
由于
a
[
i
]
x
o
r
a
[
j
]
=
=
a
[
l
c
a
(
i
,
j
)
]
a[i]\ xor\ a[j] == a[lca(i,j)]
a[i] xor a[j]==a[lca(i,j)]等价于
a
[
i
]
=
=
a
[
j
]
x
o
r
a
[
l
c
a
(
i
,
j
)
]
a[i] == a[j]\ xor\ a[lca(i,j)]
a[i]==a[j] xor a[lca(i,j)],现在考虑某颗以
u
u
u为根的子树对答案的贡献如何计算。一个容易想到的思路是,按照顺序遍历以
u
u
u的孩子
v
v
v为根的子树,对于
v
v
v子树内的某一个节点
j
j
j,令
x
=
a
[
j
]
x
o
r
a
[
u
]
x=a[j]\ xor\ a[u]
x=a[j] xor a[u],从已遍历的孩子表示的子树中,找到所有权值为
x
x
x的节点
i
i
i,将答案加上
i
x
o
r
j
i\ xor\ j
i xor j。这样算的复杂度肯定是爆炸的。
考虑按位去算,对于满足
a
[
i
]
=
=
a
[
j
]
x
o
r
a
[
l
c
a
(
i
,
j
)
]
a[i] == a[j]\ xor\ a[lca(i,j)]
a[i]==a[j] xor a[lca(i,j)]的下标
i
i
i的某一位
b
i
t
bit
bit,只有
i
i
i与
j
j
j这一位不同才对答案有贡献,于是我们可以记录
c
n
t
[
x
]
[
b
i
t
]
[
k
]
cnt[x][bit][k]
cnt[x][bit][k]表示,当前值为
x
x
x的下标中,第
b
i
t
bit
bit位为
k
k
k的下标个数。
接下来就是树上启发式合并的经典操作了,首先计算所有轻儿子对答案与
c
n
t
cnt
cnt数组的贡献,然后把轻儿子对
c
n
t
cnt
cnt数组的贡献去掉,计算重儿子对答案与
c
n
t
cnt
cnt数组的贡献即可。
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <set>
#include <vector>
#include <map>
#include <queue>
#include <cmath>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 100005;
int n;
int a[N];
struct edge
{
int to, next;
}e[N << 1];
int head[N], tot;
int cnt[1 << 20][20][2];
int siz[N], son[N];
bool vis[N];
LL ans;
void add(int u, int v)
{
e[++ tot] = {v, head[u]};
head[u] = tot;
}
void dfs(int u, int ff)
{
siz[u] = 1;
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == ff) continue;
dfs(v, u);
siz[u] += siz[v];
if (!son[u] || siz[v] > siz[son[u]]) son[u] = v;
}
}
void modify(int u, int ff, int val)
{
for (int i = 0; i < 20; i ++ )
cnt[a[u]][i][u >> i & 1] += val;
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == ff) continue;
modify(v, u, val);
}
}
void updateAns(int u, int ff, int lca)
{
int x = a[u] ^ a[lca];
for (int i = 0; i < 20; i ++ )
ans += (1LL << i) * cnt[x][i][u >> i & 1 ^ 1];
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == ff) continue;
updateAns(v, u, lca);
}
}
void dfs2(int u, int ff, int keep)
{
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == ff || son[u] == v) continue;
dfs2(v, u, 0);
}
if (son[u]) dfs2(son[u], u, 1), vis[son[u]] = true;
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == ff || vis[v]) continue;
updateAns(v, u, u);
modify(v, u, 1);
}
for (int i = 0; i < 20; i ++ )
cnt[a[u]][i][u >> i & 1] ++;
vis[son[u]] = false;
if (!keep)
modify(u, ff, -1);
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i ++ )
scanf("%d", &a[i]);
for (int i = 1; i < n; i ++ )
{
int u, v;
scanf("%d%d", &u, &v);
add(u, v);
add(v, u);
}
dfs(1, 0);
dfs2(1, 0, 0);
printf("%lld\n", ans);
return 0;
}