有一个长度为n的仅包含小写字母的字符串S,下标范围为[1,n].
现在有若干组询问,对于每一个询问,我们给出若干个后缀(以其在S中出现的起始位置来表示),求这些后缀两两之间的LCP(LongestCommonPrefix)的长度之和.一对后缀之间的LCP长度仅统计一遍.
怕忘记虚树怎么写于是挂个板子题。
LCP长度相当于后缀树上LCA深度,对于每个点统计有多少点对以它为LCA即可。按DFS序排序后反向扫一遍就相当于DFS。
#include<cstdio>
#include<algorithm>
#include<cstring>
#define gm 500005
using namespace std;
typedef long long ll;
struct node
{
node *s[26],*pa;
int dep;
int no;
bool used;
node(int dep):s(),pa(),dep(dep),no(),used(){}
void copy(const node &x)
{
memcpy(s,x.s,sizeof s);
pa=x.pa;
}
};
struct e
{
int t;
e *n;
e(int t,e *n):t(t),n(n){}
}*f[gm<<1];
int dep[gm<<1],pos[gm],dfn[gm<<1],dis[gm<<1];
int son[gm<<1],fat[gm<<1],top[gm<<1],sz[gm<<1],clk=0;
struct cmp{bool operator() (int a,int b){return dfn[a]<dfn[b];}};
void dfs1(int x)
{
sz[x]=1;
for(e *i=f[x];i;i=i->n)
{
if(i->t==fat[x]) continue;
fat[i->t]=x;
dis[i->t]=dis[x]+1;
dfs1(i->t);
sz[x]+=sz[i->t];
if(sz[i->t]>sz[son[x]]) son[x]=i->t;
}
}
void dfs2(int x)
{
dfn[x]=++clk;
top[x]=(x==son[fat[x]]?top[fat[x]]:x);
if(son[x]) dfs2(son[x]);
for(e *i=f[x];i;i=i->n)
{
if(i->t==fat[x]||i->t==son[x]) continue;
dfs2(i->t);
}
}
int LCA(int a,int b)
{
while(top[a]!=top[b])
{
if(dis[top[a]]<dis[top[b]]) a^=b^=a^=b;
a=fat[top[a]];
}
return dis[a]<dis[b]?a:b;
}
struct SAM
{
node *rt,*last;
SAM():rt(new node(0)),last(rt){}
void push_back(char c)
{
c-='a';
node *np=new node(last->dep+1);
node *p=last;
last=np;
for(;p&&!p->s[c];p=p->pa) p->s[c]=np;
if(!p)
{
np->pa=rt;
return;
}
node *q=p->s[c];
if(q->dep==p->dep+1)
{
np->pa=q;
return;
}
node *nq=new node(p->dep+1);
nq->copy(*q);
q->pa=np->pa=nq;
for(;p&&p->s[c]==q;p=p->pa) p->s[c]=nq;
}
void svt()
{
static node *q[gm<<1];
int l=0,r=0;
int ct=0;
q[++r]=rt;
while(l!=r)
{
node *x=q[++l];
++ct;
if(x->no) pos[x->no]=ct;
dep[x->no=ct]=x->dep;
if(x->pa)
{
int u=x->pa->no,v=x->no;
f[u]=new e(v,f[u]);
}
for(char i=0;i<26;++i)
{
node *y=x->s[i];
if(y&&!y->used)
{
y->used=1;
q[++r]=y;
}
}
}
dfs1(1);
dfs2(1);
}
}sam;
char s[gm];
int n,m;
int a[3000001],tot;
int stk[gm<<1],tp;
int fa[gm<<1],cnt[gm<<1];
void solve()
{
int t;
scanf("%d",&t);
for(int i=1;i<=t;++i) scanf("%d",a+i),a[i]=pos[a[i]];
sort(a+1,a+t+1,cmp());
tot=unique(a+1,a+t+1)-a-1;
stk[0]=tp=0;
for(int i=1,j=tot;i<=j;++i)
{
int x=a[i];
cnt[x]=1;
if(!tp)
{
fa[x]=0;
stk[++tp]=x;
}
else
{
int lca=LCA(x,stk[tp]);
while(tp&&dis[lca]<dis[stk[tp]])
{
if(tp==1||dis[stk[tp-1]]<=dis[lca])
{
fa[stk[tp]]=lca;
}
--tp;
}
if(stk[tp]!=lca)
{
a[++tot]=lca;
cnt[lca]=0;
fa[lca]=stk[tp];
stk[++tp]=lca;
}
fa[x]=stk[tp];
stk[++tp]=x;
}
}
sort(a+1,a+tot+1,cmp());
ll ans=0;
for(int i=tot;i>1;--i)
{
int x=a[i];
ans+=ll(cnt[x])*cnt[fa[x]]*dep[fa[x]];
cnt[fa[x]]+=cnt[x];
}
printf("%lld\n",ans);
}
int main()
{
scanf("%d%d%s",&n,&m,s);
for(int i=n-1;~i;--i)
{
sam.push_back(s[i]);
sam.last->no=i+1;
}
sam.svt();
for(int i=1;i<=m;++i)
{
solve();
}
return 0;
}