题意
有两棵树 T1T_1T1 和 T2T_2T2,大小分别为 n1,n2n_1,n_2n1,n2。构造一个新图,其中的每个节点有二元组 (u,v)(1≤u≤n1,1≤v≤n2)(u,v)(1\le u\le n_1,1\le v\le n_2)(u,v)(1≤u≤n1,1≤v≤n2) 表示。(u,v1),(u,v2)(u,v_1),(u,v_2)(u,v1),(u,v2) 相邻当且仅当在 T2T_2T2 中 v1,v2v_1,v_2v1,v2 相邻。(u1,v),(u2,v)(u_1,v),(u_2,v)(u1,v),(u2,v) 相邻当且仅当在 T1T_1T1 中 u1,u2u_1,u_2u1,u2 相邻。问新图中有多少个不同的长度为 kkk 的环。
n1,n2≤4000,k≤75n_1,n_2\le 4000,k\le 75n1,n2≤4000,k≤75。
分析
实际上就是在每棵树上分别走。如果我们能对每棵树求出stepistep_istepi表示长度为iii的环的数量的话,就可以很容易求出答案,问题在于stepistep_istepi怎么求。
考虑点分治,然后求所有经过分治中心ccc的环。
设fi,xf_{i,x}fi,x表示从ccc开始走了iii步走到xxx,且除了一开始以外不经过ccc的方案数,gi,xg_{i,x}gi,x表示从ccc开始走了iii步走到xxx的方案数。
转移比较显然,那么对于某个点xxx,从xxx开始走,经过ccc的大小为iii的环的数量就是∑j=0ifj,x∗gi−j,x\sum_{j=0}^if_{j,x}*g_{i-j,x}j=0∑ifj,x∗gi−j,x
这里可以看成是枚举第一次到达ccc是在第几步,然后后面随便走。需要特判xxx就是分治中心的情况。
这样的话总的时间复杂度就是O(nk2logn)O(nk^2\log n)O(nk2logn),如果用FFT来优化卷积的话可以做到O(nklognlogk)O(nk\log n\log k)O(nklognlogk)。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=4005;
const int M=80;
const int MOD=998244353;
int m,jc[M],ny[M];
struct Tree
{
int n,cnt,last[N],f[M][N],g[M][N],size[N],ans[M],w[N],tot,a[N],sum,root;
bool vis[N];
struct edge{int to,next;}e[N*2];
void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}
void get_root(int x,int fa)
{
size[x]=1;w[x]=0;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa||vis[e[i].to]) continue;
get_root(e[i].to,x);
size[x]+=size[e[i].to];
w[x]=std::max(w[x],size[e[i].to]);
}
w[x]=std::max(w[x],sum-size[x]);
if (!root||w[x]<w[root]) root=x;
}
void get(int x,int fa)
{
a[++tot]=x;size[x]=1;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&!vis[e[i].to]) get(e[i].to,x),size[x]+=size[e[i].to];
}
void calc(int x)
{
tot=0;get(x,0);
for (int i=1;i<=tot;i++) f[0][a[i]]=0;
f[0][x]=g[0][x]=1;
for (int i=1;i<=m;i++)
for (int j=1;j<=tot;j++)
{
int y=a[j];
f[i][y]=g[i][y]=0;
for (int k=last[y];k;k=e[k].next)
{
if (vis[e[k].to]) continue;
(g[i][y]+=g[i-1][e[k].to])%=MOD;
if (y!=x) (f[i][y]+=f[i-1][e[k].to])%=MOD;
}
}
for (int i=1;i<=tot;i++)
{
int y=a[i];
if (y==x)
{
for (int j=0;j<=m;j++) (ans[j]+=g[j][x])%=MOD;
continue;
}
for (int j=0;j<=m;j++)
for (int k=0;j+k<=m;k++)
(ans[j+k]+=(LL)f[j][y]*g[k][y]%MOD)%=MOD;
}
vis[x]=1;
for (int i=last[x];i;i=e[i].next)
{
if (vis[e[i].to]) continue;
root=0;sum=size[e[i].to];
get_root(e[i].to,x);
calc(root);
}
}
void solve()
{
sum=n;root=0;
get_root(1,0);
calc(root);
}
}t1,t2;
int C(int n,int m)
{
return (LL)jc[n]*ny[m]%MOD*ny[n-m]%MOD;
}
int main()
{
scanf("%d%d%d",&t1.n,&t2.n,&m);
jc[0]=jc[1]=ny[0]=ny[1]=1;
for (int i=2;i<=m;i++) jc[i]=(LL)jc[i-1]*i%MOD,ny[i]=(LL)(MOD-MOD/i)*ny[MOD%i]%MOD;
for (int i=2;i<=m;i++) ny[i]=(LL)ny[i-1]*ny[i]%MOD;
for (int i=1;i<t1.n;i++)
{
int x,y;scanf("%d%d",&x,&y);
t1.addedge(x,y);
}
for (int i=1;i<t2.n;i++)
{
int x,y;scanf("%d%d",&x,&y);
t2.addedge(x,y);
}
t1.solve();t2.solve();
int s=0;
for (int i=0;i<=m;i++)
(s+=(LL)t1.ans[i]*t2.ans[m-i]%MOD*C(m,i)%MOD)%=MOD;
printf("%d",s);
return 0;
}