题意
给一棵树, n ≤ 1 0 5 n\le 10^5 n≤105。
对每个点,先钦定他为根,然后把与根的距离超过 k ( k ≤ 10 ) k(k\le 10) k(k≤10) 的点删掉。对于剩下的树,求两个值:
- 树上点的个数
- 树上所有点的子树大小的乘积
思路
简单 毒瘤树型 DP。
状态很显然, f [ u ] [ k ] f[u][k] f[u][k] 表示 u u u 为根的大小为 k k k 的子树的点数, 而 g [ u ] [ k ] g[u][k] g[u][k] 表示除了 u u u 之外的所有点子树大小的乘积。
因为有贡献的点不只是在子树中,所以还要换根 DP 。在换根 DP 的时候,有一个小小的 trick 。在递归之前, f , g f,g f,g 中存的是 u u u 的最终答案。而在进入某个子树之前,先减去这棵子树的权值,让 f , g f,g f,g 的定义变为 u u u 对这个将要递归进去的儿子的贡献,那么转移就相对方便了。
注意明确 DP 数组的定义,并且尽量简化转移。
代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = N<<1, K = 10 + 5, mod = 1e9 + 7;
int n, k;
int hh[N], ecnt, v[M], nxt[M];
int f[N][K], g[N][K];
template<class T>inline void read(T &x){
x = 0; char c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) x = (x<<3)+(x<<1)+c-'0', c = getchar();
}
template<class T>inline void wr(T x){
if (x > 9) wr(x/10);
putchar(x%10+'0');
}
int inv(int x){
if (x == 0) return 1;
int y = mod-2, ret = 1;
while (y){
if (y&1) ret = 1LL*ret*x%mod;
x = 1LL*x*x%mod;
y >>= 1;
}
return ret;
}
void _add(int x, int y){
nxt[++ecnt] = hh[x]; v[ecnt] = y;
hh[x] = ecnt;
}
void dfs(int u, int fa){
for (int i = 0; i <= k; ++ i){
f[u][i] = 1;
g[u][i] = 1;
}
for (int i = hh[u]; i; i = nxt[i])
if (v[i] != fa){
dfs(v[i], u);
for (int j = 1; j <= k; ++ j){
f[u][j] += f[v[i]][j-1];
g[u][j] = 1LL*g[u][j]*g[v[i]][j-1]%mod*f[v[i]][j-1]%mod;
}
}
}
void dfs1(int u, int fa){
if (u != 1){
for (int i = 1; i <= k; ++ i){
f[u][i] += f[fa][i-1];
g[u][i] = 1LL*g[u][i]*g[fa][i-1]%mod*f[fa][i-1]%mod;
}
}
int tmp1[K], tmp2[K];
for (int i = hh[u]; i; i = nxt[i])
if (v[i] != fa){
for (int j = 1; j <= k; ++ j){
f[u][j] -= f[v[i]][j-1];
g[u][j] = 1LL*g[u][j]*inv(1LL*g[v[i]][j-1]*f[v[i]][j-1]%mod)%mod;
tmp1[j] = f[v[i]][j-1];
tmp2[j] = 1LL*g[v[i]][j-1]*f[v[i]][j-1]%mod;
}
dfs1(v[i], u);
for (int j = 1; j <= k; ++ j){
f[u][j] += tmp1[j];
g[u][j] = 1LL*g[u][j]*tmp2[j]%mod;
}
}
}
int main()
{
read(n); read(k);
ecnt = 1;
for (int i = 1; i < n; ++ i){
int x, y;
read(x); read(y);
_add(x, y); _add(y, x);
}
dfs(1, 0);
dfs1(1, 0);
for (int i = 1; i <= n; ++ i)
wr(f[i][k]), putchar(' ');
puts("");
for (int i = 1; i <= n; ++ i)
wr(1LL*f[i][k]*g[i][k]%mod), putchar(' ');
puts("");
return 0;
}