<更新提示>
<正文>
幸运数字
Description
A 国共有 n 座城市,这些城市由 n-1 条道路相连,使得任意两座城市可以互达,且路径唯一。每座城市都有一个幸运数字,以纪念碑的形式矗立在这座城市的正中心,作为城市的象征。
一些旅行者希望游览 A 国。旅行者计划乘飞机降落在 x 号城市,沿着 x 号城市到 y 号城市之间那条唯一的路径游览,最终从 y 城市起飞离开 A 国。在经过每一座城市时,游览者就会有机会与这座城市的幸运数字拍照,从而将这份幸运保存到自己身上。然而,幸运是不能简单叠加的,这一点游览者也十分清楚。他们迷信着幸运数字是以异或的方式保留在自己身上的。
例如,游览者拍了 3 张照片,幸运值分别是 5,7,11,那么最终保留在自己身上的幸运值就是 9(5 xor 7 xor 11)。
有些聪明的游览者发现,只要选择性地进行拍照,便能获得更大的幸运值。例如在上述三个幸运值中,只选择 5 和 11 ,可以保留的幸运值为 14 。现在,一些游览者找到了聪明的你,希望你帮他们计算出在他们的行程安排中可以保留的最大幸运值是多少。
Input Format
第一行包含 2 个正整数 n ,q,分别表示城市的数量和旅行者数量。
第二行包含 n 个非负整数,其中第 i 个整数 Gi 表示 i 号城市的幸运值。
随后 n-1 行,每行包含两个正整数 x ,y,表示 x 号城市和 y 号城市之间有一条道路相连。
随后 q 行,每行包含两个正整数 x ,y,表示这名旅行者的旅行计划是从 x 号城市到 y 号城市。N<=20000,Q<=200000,Gi<=2^60
Output Format
输出需要包含 q 行,每行包含 1 个非负整数,表示这名旅行者可以保留的最大幸运值。
Sample Input
4 2
11 5 7 9
1 2
1 3
1 4
2 3
1 4
Sample Output
14
11
Solution
元素的异或最大值考虑使用线性基来求解,那么问题就是如何处理树上路径,转为序列问题。
最粗暴的想法是树链剖分,然后用线段树维护树上路径的线性基,向上合并即可。查询也是树链剖分的方法,时间复杂度 O ( n log 3 n + q log 4 n ) O(n\log^3n+q\log^4n) O(nlog3n+qlog4n),因为线性基的合并是 O ( l o g 2 n ) O(log^2n) O(log2n)的。
考虑到没有修改操作,使用线段树比较浪费,可以直接树上倍增,时间复杂度优化为 O ( ( n + q ) log 3 n ) O((n+q)\log^3n) O((n+q)log3n)。
更好的做法是点分治。我们知道点分治可以统计树上路径问题,那么点分治同样也可以回答树上路径询问。具体的说,我们每次回答经过分治重心的路径所对应的询问,只需在 d f s dfs dfs这棵树的同时求出重心到每一个点路径上的线性基,那么对于每一个询问,只要合并两个线性基即可,时间复杂度优化为 O ( ( n + q ) log 2 n ) O((n+q)\log^2n) O((n+q)log2n)。
参考代码如下:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N = 20020 , M = 200020;
struct LinearBase
{
ll v[65];
LinearBase () { memset( v , 0 , sizeof v ); }
inline void insert(ll x)
{
for (int i=60;i>=0;i--)
if ( x & ( 1LL << i ) )
if ( v[i] ) x ^= v[i];
else return v[i] = x , void();
}
inline ll query(void)
{
ll res = 0;
for (int i=60;i>=0;i--)
res = max( res , res ^ v[i] );
return res;
}
};
typedef LinearBase lBase;
struct edge { int ver,next; } e[N*2];
int n,m,t,Head[N],rec[N]; ll a[N],ans[M];
int size[N],Max[N],flag[N],root,tot;
vector <pii> q[N]; lBase b[N];
inline int read(void)
{
int x = 0 , w = 0; char ch = ' ';
while ( !isdigit(ch) ) w |= ch == '-' , ch = getchar();
while ( isdigit(ch) ) x = x * 10 + ch - 48 , ch = getchar();
return w ? -x : x;
}
inline ll readll(void)
{
ll x = 0 , w = 0; char ch = ' ';
while ( !isdigit(ch) ) w |= ch == '-' , ch = getchar();
while ( isdigit(ch) ) x = x * 10LL + ch - 48 , ch = getchar();
return w ? -x : x;
}
inline void insert(int x,int y) { e[++t] = (edge){y,Head[x]} , Head[x] = t; }
inline lBase merge(lBase P,lBase Q)
{
lBase res;
for (int i=0;i<=60;i++)
{
if ( P.v[i] ) res.insert(P.v[i]);
if ( Q.v[i] ) res.insert(Q.v[i]);
}
return res;
}
inline void input(void)
{
n = read() , m = read();
for (int i=1;i<=n;i++) a[i] = readll();
for (int i=1;i<n;i++)
{
int x = read() , y = read();
insert(x,y) , insert(y,x);
}
for (int i=1;i<=m;i++)
{
int u = read() , v = read();
if ( u == v ) { ans[i] = a[u]; continue; }
q[u].push_back( make_pair(v,i) );
q[v].push_back( make_pair(u,i) );
}
}
inline void dp(int x,int fa)
{
Max[x] = 0 , size[x] = 1;
for (int i=Head[x];i;i=e[i].next)
{
int y = e[i].ver;
if ( y == fa || flag[y] ) continue;
dp( y , x );
size[x] += size[y];
Max[x] = max( Max[x] , size[y] );
}
Max[x] = max( Max[x] , tot - size[x] );
if ( Max[x] < Max[root] ) root = x;
}
inline void dfs1(int x,int fa,int rt)
{
b[x] = b[fa] , b[x].insert(a[x]);
for (int i=0;i<q[x].size();i++)
{
pii t = q[x][i];
if ( !rec[t.first] ) continue;
lBase P = merge( b[t.first] , b[x] );
P.insert( a[rt] ) , ans[t.second] = P.query();
}
for (int i=Head[x];i;i=e[i].next)
{
int y = e[i].ver;
if ( y == fa || flag[y] ) continue;
dfs1( y , x , rt );
}
}
inline void dfs2(int x,int fa,int v)
{
rec[x] = v;
for (int i=Head[x];i;i=e[i].next)
{
int y = e[i].ver;
if ( y == fa || flag[y] ) continue;
dfs2( y , x , v );
}
}
inline void divide(int x)
{
flag[x] = true;
b[x].clear() , rec[x] = true;
for (int i=Head[x];i;i=e[i].next)
{
int y = e[i].ver;
if ( flag[y] ) continue;
dfs1( y , 0 , x );
dfs2( y , 0 , 1 );
}
dfs2( x , 0 , 0 );
for (int i=Head[x];i;i=e[i].next)
{
int y = e[i].ver;
if ( flag[y] ) continue;
tot = size[y] , root = 0;
dp( y , 0 ) , divide(root);
}
}
int main(void)
{
input();
root = 0 , tot = n , Max[0] = 1<<30;
dp( 1 , 0 ) , divide( root );
for (int i=1;i<=m;i++)
printf("%lld\n",ans[i]);
return 0;
}
<后记>