题目大意
给定一棵有n个节点的树,每个点有黑白两色之一。
Alice和Bob轮流操作,每次先手选择一个白点,将其到根路径上所有点染黑。一开始Alice是先手,两人都选最优策略,不能操作者输。
要求输出Alice是否能赢。如果能赢,那么第一步可以选择哪些节点。
初始时不是所有点都为白点,每个点的颜色会给定。
题目分析
我们先将黑点去掉,白点的父亲设置为它到根路径上第一个白点,重建出一个森林,显然一个子树的胜负状况与其它点无关,且能够通过根节点各儿子节点的子树的胜负状况决定。
那么我们就可以使用sg函数来求解了。设sg(x)表示以x为根的子树的估价函数。
考虑如何递推估价函数。显然,删除
那么sg(x)=mex{s(y)|y∈subtree(x)}。
当然我们肯定不能每次计算都枚举子树,考虑是否存在递推关系。
显然某一棵子树的s(y)集合,到了计算其根节点父亲子树时,都异或上了根节点其它儿子子树的估价函数异或和。
那么我们可以使用字典树保存集合,需要维护的操作有整棵树异或一个数(打标记下传即可),插入一个数,求mex(记录每棵子树是否已满,如果0边没有满,就沿着
最后对森林所有树做Nim和就知道先手胜负了,设结果为nim。
考虑怎么找第一步可行操作。根据sg函数的定义,我们只需在森林每一棵树中找到所有估价函数为nim xor sg(x)的所有路径起点即可,这个东西可以使用模拟链表,在插入时顺便插入,在合并时也合并好。
别忘了处理子树x时要将删除其本身的这种转移的估价函数插入字典树,其值为所有儿子子树估价函数异或和。
时间复杂度
代码实现
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
#include <cmath>
using namespace std;
const int N=100500;
const int M=N<<1;
const int L=31;
const int S=N*L;
int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch))
{
if (ch=='-')
f=-1;
ch=getchar();
}
while (isdigit(ch))
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
int buf[20];
void write(int x)
{
if (x<0)
putchar('-'),x=-x;
buf[0]=0;
while (x)
{
buf[++buf[0]]=x%10;
x/=10;
}
while (buf[0])
putchar('0'+buf[buf[0]--]);
}
int power[L];
int ans[N];
struct Trie
{
int tag[S],son[S][2],head[S],tail[S],id[S];
int v[N],next[N],root[N];
bool full[S];
int tot,cnt;
void X(int rt,int edit)
{
if (id[rt]<0||!rt)
return;
tag[rt]^=edit;
if ((edit>>id[rt])&1)
swap(son[rt][0],son[rt][1]);
}
void clear(int rt)
{
if (!rt)
return;
if (tag[rt])
{
X(son[rt][0],tag[rt]),X(son[rt][1],tag[rt]);
tag[rt]=0;
}
}
void update(int x)
{
full[x]=full[son[x][0]]&&full[son[x][1]];
}
int newnode()
{
tag[++tot]=0;
son[tot][0]=son[tot][1]=0;
return tot;
}
void init()
{
tot=0;
}
int merge(int rt1,int rt2)
{
clear(rt1),clear(rt2);
if (!rt1||!rt2)
return rt1+rt2;
if (rt1==rt2)
return rt1;
next[tail[rt1]]=head[rt2];
if (tail[rt2])
tail[rt1]=tail[rt2];
son[rt1][0]=merge(son[rt1][0],son[rt2][0]);
son[rt1][1]=merge(son[rt1][1],son[rt2][1]);
if (id[rt1]!=-1)
update(rt1);
return rt1;
}
int insert(int dig,int rt,int x,int y)
{
if (!rt)
rt=newnode(),id[rt]=dig;
clear(rt);
if (dig==-1)
{
full[rt]=true;
v[++cnt]=y;
next[cnt]=head[rt];
if (!head[rt])
tail[rt]=cnt;
head[rt]=cnt;
return rt;
}
int w=(x>>dig)&1;
son[rt][w]=insert(dig-1,son[rt][w],x,y);
update(rt);
return rt;
}
int mex(int dig,int rt)
{
clear(rt);
if (dig==-1)
return 0;
if (full[son[rt][0]])
return mex(dig-1,son[rt][1])+power[dig];
else
return mex(dig-1,son[rt][0]);
}
void search(int rt,int x)
{
for (int i=L-1;i>=0;i--)
{
clear(rt);
int w=(x>>i)&1;
rt=son[rt][w];
if (!rt)
return;
}
int i=head[rt];
while (i)
{
ans[++ans[0]]=v[i];
i=next[i];
}
}
}t;
int sg[N],fa[N],rts[N],last[N];
int next[M],tov[M];
bool color[N];
int n,tot;
void insert(int x,int y)
{
tov[++tot]=y;
next[tot]=last[x];
last[x]=tot;
}
void dfs(int x)
{
int i=last[x],y,xors=0;
while (i)
{
y=tov[i];
dfs(y);
xors^=sg[y];
i=next[i];
}
i=last[x];
while (i)
{
y=tov[i];
t.X(t.root[y],xors^sg[y]);
i=next[i];
}
i=last[x];
int R=0;
while (i)
{
y=tov[i];
R=t.merge(R,t.root[y]);
i=next[i];
}
t.root[x]=t.insert(L-1,R,xors,x);
sg[x]=t.mex(L-1,t.root[x]);
}
void build(int x,int fr,int z)
{
fa[x]=z;
int i=last[x],y;
while (i)
{
y=tov[i];
if (y!=fr)
if (!color[x])
build(y,x,x);
else
build(y,x,z);
i=next[i];
}
}
int main()
{
power[0]=1;
for (int i=1;i<L;i++)
power[i]=power[i-1]<<1;
freopen("dierti.in","r",stdin);
freopen("dierti.out","w",stdout);
n=read();
for (int i=1;i<=n;i++)
color[i]=read();
for (int i=1,x,y;i<n;i++)
{
x=read(),y=read();
insert(x,y),insert(y,x);
}
build(1,0,0);
memset(last,0,sizeof last);
memset(tov,0,sizeof tov);
memset(next,0,sizeof next);
tot=0;
for (int i=1;i<=n;i++)
if (fa[i]&&!color[i])
insert(fa[i],i);
for (int i=1;i<=n;i++)
if (!fa[i]&&!color[i])
rts[++rts[0]]=i;
t.init();
for (int i=1;i<=rts[0];i++)
dfs(rts[i]);
int xors=0;
for (int i=1;i<=rts[0];i++)
xors^=sg[rts[i]];
if (xors)
{
for (int i=1;i<=n;i++)
t.search(t.root[rts[i]],xors^sg[rts[i]]);
sort(ans+1,ans+1+ans[0]);
for (int i=1;i<=ans[0];i++)
write(ans[i]),putchar('\n');
}
else
write(-1);
fclose(stdin);
fclose(stdout);
return 0;
}