分析
这题的难点是如何记录树上的每个点从起点到这个点的某个字符串的出现次数,而这样的空间是巨大的,所以我们可以联想到主席树这种可持久化的数据结构,所以我们也可以根据字符串建立类似的主席树——字典树。
我们可以建一棵字典树,对于每个询问(x,y),字符s。我们先算出从根到x中的s出现的次数dat1,然后再计算起点到y中的s出现的次数dat2,再算从起点到lca(x,y)的s出现的次数dat3。明显的,答案就是dat1+dat2-dat3*2.
而我们考虑如何建一棵字典树。
因为我们遍历的是一棵树,所以遍历是有顺序的,对于我们遍历到的一个点x,我们便可以从father(x),中去更新x节点的字典树——在father(x)节点的字典树上插入一个字符串s[x]即可,具体的:
int add(int x,int r){
int temp,y;temp=y=++cnt;
for(int i=0;i<=strlen(sr[r])-1;i++){
for(int j=0;j<=25;j++)tri[y][j]=tri[x][j];
x=tri[x][sr[r][i]-97];
y=tri[y][sr[r][i]-97]=++cnt;//新建一个节点
sum[y]=sum[x]+1;//累加
}
return temp;//返回这棵字典树的根节点
}
void dfs(int x,int y){
for(int p=las[x];p;p=nex[p]){
if (b[p]!=y){
root[b[p]]=root[x];
root[b[p]]=add(root[x],(p+1)/2);//建树
g[b[p]][0]=x;//预处理rmq数组
d[b[p]]=d[x]+1;//深度
dfs(b[p],x);
}
}
}
字典树相对于主席树,没有主席树那么复杂,也不用递归实现,因为在字典树里面的每个节点都一定有26个点(”a”..”z”),所以直接循环处理便可以了。
所以这题也就加上一个lca便可以解决了。
代码
#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
using namespace std;
const int N=100005;
char sr[N][11],s[11];
int num,n,b[N*2],las[N],nex[N*2],x,y,g[N][18],cnt,tri[N*10][27],sum[N*10],root[N],d[N],t;
void insert(int x,int y){
b[++num]=y;nex[num]=las[x];las[x]=num;
}
int add(int x,int r){
int temp,y;temp=y=++cnt;
for(int i=0;i<=strlen(sr[r])-1;i++){
for(int j=0;j<=25;j++)tri[y][j]=tri[x][j];
x=tri[x][sr[r][i]-97];
y=tri[y][sr[r][i]-97]=++cnt;
sum[y]=sum[x]+1;
}
return temp;
}
void dfs(int x,int y){
for(int p=las[x];p;p=nex[p]){
if (b[p]!=y){
root[b[p]]=root[x];
root[b[p]]=add(root[x],(p+1)/2);
g[b[p]][0]=x;
d[b[p]]=d[x]+1;
dfs(b[p],x);
}
}
}
int lca(int x,int y){
if (d[x]<d[y]) swap(x,y);
for(int k=trunc(log2(d[x]-d[y]+1));k>=0;k--)
if (d[g[x][k]]>d[y]) x=g[x][k];
if (d[x]!=d[y]) x=g[x][0];
for(int k=trunc(log2(d[x]));k>=0;k--) if (g[x][k]!=g[y][k]) x=g[x][k],y=g[y][k];
if (x==y) return x;else return g[x][0];
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n-1;i++){
scanf("%d %d %s",&x,&y,&sr[i]);
insert(x,y);
insert(y,x);
}
d[1]=1;
dfs(1,0);
for(int j=1;j<=17;j++)
for(int i=1;i<=n;i++)
g[i][j]=g[g[i][j-1]][j-1];
scanf("%d",&t);cnt=1;
for(int T=1;T<=t;T++){
scanf("%d %d %s",&x,&y,&s);
int z=lca(x,y);
int p=root[x],ans=0;
for(int i=0;i<=strlen(s)-1;i++) p=tri[p][s[i]-97];
ans+=sum[p];
p=root[y];
for(int i=0;i<=strlen(s)-1;i++) p=tri[p][s[i]-97];
ans+=sum[p];
p=root[z];
for(int i=0;i<=strlen(s)-1;i++) p=tri[p][s[i]-97];
ans-=sum[p]*2;
printf("%d\n",ans);
}
}