1089: [SCOI2003]严格n元树
Time Limit: 1 Sec Memory Limit: 162 MBSubmit: 1542 Solved: 764
[ Submit][ Status][ Discuss]
Description
如果一棵树的所有非叶节点都恰好有n个儿子,那么我们称它为严格n元树。如果该树中最底层的节点深度为d
(根的深度为0),那么我们称它为一棵深度为d的严格n元树。例如,深度为2的严格2元树有三个,如下图:
给出n, d,编程数出深度为d的n元树数目。
Input
仅包含两个整数n, d( 0 < n < = 32, 0 < = d < = 16)
Output
仅包含一个数,即深度为d的n元树的数目。
Sample Input
2 2
【样例输入2】
2 3
【样例输入3】
3 5
Sample Output
3
【样例输出2】
21
【样例输出2】
58871587162270592645034001
HINT
Source
题解:dp+高精度
如果我们考虑在上一层的哪些叶节点上在加一层的话,必然需要考虑上一层叶节点的个数,但是叶节点可能会很多,所以我们无法记录下来。
所以换一种思考方式,在底下加入需要考虑叶子节点的数量,但是如果考虑从上面加入就可以避免了。什么意思呢,就是假设我们已经得到了d-1层的答案,那么我们考虑加入根节点,因为是严格n元树,所以必然有n个儿子,我们可以通过在他的儿子节点上加入d-1层的树来得到d层的树。因为我们要得到d层的,所以必然存在至少一个子节点接的是d-1层的子树。我们当然可以通过枚举选取几个挂d-1层的子树再结合上组合数来更新答案,但是这样太麻烦了。如果不考虑d层的限制,根的每一个儿子都可以自由选择深度为0-(d-1)的每一颗子树,答案为sum[d-1]^n (sum[i]表示深度为0-i的所有子树的答案) 那么不合法的方案数其实就是sum[d-2]^n,因为在0-(d-2)中选择,无论如何都无法满足条件。
f[i]=sum[i-1]^n-sum[i-2]^n 剩下的就是高精度的问题啦。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 1003
using namespace std;
int n,m,base[N],ans[N],c[N],f[N][N];
struct data
{
int a[N];
}sum,sum1,num;
void mul(data &x,data y)
{
for (int i=1;i<=y.a[0];i++)
base[i]=ans[i]=y.a[i];
base[0]=y.a[0],ans[0]=y.a[0];
for (int i=2;i<=n;i++) {
memset(c,0,sizeof(c));
for (int j=1;j<=base[0];j++)
for (int k=1;k<=ans[0];k++)
c[j+k-1]+=base[j]*ans[k];
int t=max(base[0],ans[0]);
int x=0;
for (int j=1;j<=t;j++)
{
c[j]=c[j]+x;
x=c[j]/10000;
c[j]%=10000;
}
c[t+1]+=x;
while (c[t+1]) {
t++;
c[t+1]+=c[t]/10000;
c[t]%=10000;
}
for (int j=1;j<=t;j++)
ans[j]=c[j];
ans[0]=t;
}
x.a[0]=ans[0];
for (int i=1;i<=ans[0];i++)
x.a[i]=ans[i];
}
void calc(int f[N],data sum1,data sum)
{
for (int i=1;i<=sum1.a[0];i++)
{
if (sum1.a[i]-sum.a[i]<0){
sum1.a[i+1]--;
sum1.a[i]+=10000;
}
f[i]=sum1.a[i]-sum.a[i];
}
f[0]=sum1.a[0];
while (!f[f[0]]) f[0]--;
}
void add(int f[N],data &x)
{
int t=max(f[0],x.a[0]);
for (int i=1;i<=t;i++)
{
x.a[i]+=f[i];
}
for (int i=1;i<=t;i++)
x.a[i+1]+=x.a[i]/10000,x.a[i]%=10000;
while (x.a[t+1]) {
t++;
x.a[t+1]+=x.a[t]/10000;
x.a[t]%=10000;
}
x.a[0]=t;
}
int main()
{
scanf("%d%d",&n,&m);
if (n==0){
printf("1\n");
return 0;
}
m++;
f[1][1]=1; f[1][0]=1;
sum.a[1]=1; sum.a[0]=1; mul(sum,sum);
f[2][1]=1; f[2][0]=1;
sum1.a[0]=1; sum1.a[1]=2; mul(sum1,sum1); num.a[0]=1; num.a[1]=2;
for (int i=3;i<=m;i++)
{
calc(f[i],sum1,sum);
for (int j=0;j<=sum1.a[0];j++)
sum.a[j]=sum1.a[j];
add(f[i],num);
mul(sum1,num);
}
for (int i=f[m][0];i>=1;i--){
if (i!=f[m][0]){
if (f[m][i]<10) printf("000");
else if (f[m][i]<100) printf("00");
else if (f[m][i]<1000) printf("0");
}
printf("%d",f[m][i]);
}
printf("\n");
}