题目大意
给你一棵树,让你求树上所有满足 a i ⨁ a j = a l c a ( i , j ) a_i \bigoplus a_j = a_{lca(i, j)} ai⨁aj=alca(i,j) 的点对 i ⨁ j i \bigoplus j i⨁j 之和
解题思路
我们可以枚举每个点作为
l
c
a
lca
lca 的贡献,这是一个不修改的子树问题,可以用树上启发式合并来做(dsu on tree),复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
我们怎么求
i
⨁
j
i \bigoplus j
i⨁j,我们可以用一个桶来记录
a
i
⨁
a
r
t
a_i \bigoplus a_{rt}
ai⨁art,然后将
i
i
i 存入桶中。但是我们每放一个数,首先需要与桶中的数计算一次答案,这样复杂度变为了
O
(
n
2
l
o
g
n
)
O(n^2logn)
O(n2logn)。
我们知道
i
⨁
j
+
i
⨁
k
!
=
i
⨁
(
j
+
k
)
i \bigoplus j + i \bigoplus k != i \bigoplus(j + k)
i⨁j+i⨁k!=i⨁(j+k)
但是我们将数化为二进制,我们每一位来运算是满足结合律的,所以我们将
i
i
i 化为二进制存入桶中,我们统计每一位出现了多少次 0 或 1,这样我们可以每一次按位运算。
我们在dsu的时候,先递归计算子树的答案,然后加上子树的贡献。
Code
#include <bits/stdc++.h>
#define ll long long
#define qc ios::sync_with_stdio(false); cin.tie(0);cout.tie(0)
#define fi first
#define se second
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define pb push_back
using namespace std;
const int MAXN = 1e5 + 7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 1e9 + 7;
ll ans;
int son[MAXN], sz[MAXN], a[MAXN];
int cnt[MAXN*15][21][2];
int n;
int head[MAXN];
struct edge{
int to, next;
}e[MAXN << 1];
int tot;
void add(int u, int v){
e[tot].to = v;
e[tot].next = head[u];
head[u] = tot++;
}
void init(){
memset(head, -1, sizeof head);
tot = 0;
}
void dfs(int u, int f){
sz[u] = 1;
int maxx = 0;
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to;
if(v == f) continue;
dfs(v, u);
sz[u] += sz[v];
if(sz[v] > maxx){
maxx = sz[v];
son[u] = v;
}
}
}
int flag;
void getAns(int x, int u, int f){
int k = a[x] ^ a[u];
for(int i = 0; i <= 20; i++){
ans += cnt[k][i][!((u >> i) & 1)] * (1ll << i);
}
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to;
if(v == flag || v == f) continue;
getAns(x, v, u);
}
}
void update(int x, int u, int f, int val){
for(int i = 0; i <= 20; i++){
cnt[a[u]][i][(u >> i) & 1] += val;
}
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to;
if(v == flag || v == f) continue;
update(x, v, u, val);
}
}
void dsu(int u, int f, int keep){
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to;
if(v == f || v == son[u]) continue;
dsu(v, u, 0);
}
if(son[u]){
dsu(son[u], u, 1);
flag = son[u];
}
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to;
if(v == flag || v == f) continue;
getAns(u, v, u);
update(u, v, u, 1);
}
flag = 0;
for(int i = 0; i <= 20; i++){
cnt[a[u]][i][(u >> i) & 1]++;
}
if(!keep){
update(u, u, f, -1);
}
}
void solve(){
cin >> n;
init();
for (int i = 1; i <= n; ++i){
cin >> a[i];
}
for (int i = 1; i <= n-1; ++i){
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
dfs(1, 0);
dsu(1, 0, 0);
cout << ans << "\n";
}
int main()
{
#ifdef ONLINE_JUDGE
#else
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
qc;
int T;
// cin >> T;
T = 1;
while(T--){
solve();
}
return 0;
}