题目大意
ansu=∑i在u子树中∑j在u子树中且i<j(vi xor vj)∗LCP(Si,Sj)
每个点都有点权v和一个字符串S,求ans[]。
做法
可以想到把v拆位做于是现在变成了黑白树。
可以想到dsu on tree,trie上节点记录子树内某个颜色的数量即可。
复杂度很大,但是树剖log较小,这题数据还很迷,可以跑过去。
#include<cstdio>
#include<algorithm>
#include<cstring>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
typedef long long ll;
const int maxn=100000+10,maxs=500000+10,maxd=16;
char s[maxs],str[maxs];
ll ans[maxn],sum[maxn],wdc;
int left[maxn],len[maxn],v[maxn],size[maxn],big[maxn];
int h[maxn],go[maxn*2],nxt[maxn*2],sta[maxn];
int g[maxs][27],num[2][maxs];
bool bz[maxn];
int i,j,k,l,t,n,m,tot,top,root;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void insert(int x,int y){
if (y>m) return;
//int c=str[y]-'a';
if (!g[x][str[y]-'a']) g[x][str[y]-'a']=++tot;
insert(g[x][str[y]-'a'],y+1);
/*if (!g[x][c]) g[x][c]=++tot;
insert(g[x][c],y+1);*/
}
void add(int x,int y){
go[++tot]=y;
nxt[tot]=h[x];
h[x]=tot;
}
void travel(int x,int y){
size[x]=1;
int t=h[x],k=0;
while (t){
if (go[t]!=y){
travel(go[t],x);
if (!k||size[go[t]]>size[k]) k=go[t];
size[x]+=size[go[t]];
}
t=nxt[t];
}
big[x]=k;
}
void query(int x,int y,int z,int f,bool p){
if (f==1&&x!=1) wdc+=num[1-p][x];
num[p][x]+=f;
//if (f==1) num[p][x]++;else num[p][x]=0;
if (z==0) return;
query(g[x][s[y]-'a'],y+1,z-1,f,p);
}
void dg(int x,int y,int f,int p){
query(root,left[x],len[x],f,bz[x]);
int t=h[x];
while (t){
if (go[t]!=y&&go[t]!=p) dg(go[t],x,f,p);
t=nxt[t];
}
}
void dfs(int x,int y,bool p){
if (!big[x]){
sum[x]=0;
if (p) dg(x,y,1,0);
return;
}
int t=h[x];
while (t){
if (go[t]!=y&&go[t]!=big[x]) dfs(go[t],x,0);
t=nxt[t];
}
dfs(big[x],x,1);
sum[x]=sum[big[x]];
wdc=0;
dg(x,y,1,big[x]);
sum[x]+=wdc;
if (!p) dg(x,y,-1,0);
}
void write(ll x){
if (!x){
putchar('0');
putchar('\n');
return;
}
top=0;
while (x){
sta[++top]=x%10;
x/=10;
}
while (top) putchar('0'+sta[top--]);
putchar('\n');
}
int main(){
freopen("tree.in","r",stdin);freopen("tree.out","w",stdout);
n=read();
fo(i,1,n) v[i]=read();
root=1;tot=1;
fo(i,1,n){
scanf("%s",str+1);
len[i]=m=strlen(str+1);
left[i]=top+1;
fo(j,1,m) s[++top]=str[j];
insert(root,1);
}
tot=0;
fo(i,1,n-1){
j=read();k=read();
add(j,k);add(k,j);
}
travel(1,0);
fo(i,0,maxd){
fo(j,1,n)
if (v[j]&(1<<i)) bz[j]=1;else bz[j]=0;
fo(j,1,n) sum[j]=0;
dfs(1,0,0);
fo(j,1,n) ans[j]+=(ll)sum[j]*(1<<i);
}
fo(i,1,n) write(ans[i]);
}