Description
给出一棵nnn个节点的树,初始每个点为白色,对于一个点集SSS,定义f(s)f(s)f(s)为:将SSS中点染成黑色,如果存在任何白点在两个黑点路径之间则f(S)=0f(S)=0f(S)=0,否则选取一个路径集合,使得该路径集合中每条路径斗不包含黑点,之后把路径上的点染成红色,此时f(S)f(S)f(S)为使得所有黑点的邻接点为黑点或红点的路径集合数量,要求对于每个非空子集SSS计算f(S)f(S)f(S)之和
Input
第一行一整数nnn表示点数,之后n−1n-1n−1行每行输入两个整数表示一条树边
(1≤n≤2⋅105)(1\le n\le 2\cdot 10^5)(1≤n≤2⋅105)
Output
输出f(S)f(S)f(S)之和,结果模998244353998244353998244353
Sample Input
2
1 2
Sample Output
3
Solution
第二步所选边集会将这棵树分成若干没有红点的连通分支,每个连通分支均可选来做第一步所选点集,连通分支数=未被染成红色点的数量-两个端点均未被染成红色的边的数量,那么单独考虑每个点和每条边对答案的贡献
1.一个点uuu对答案的贡献即为选出不经过该点的边集数量,令num[x]=2x(x+1)2num[x]=2^{\frac{x(x+1)}{2}}num[x]=22x(x+1)为有xxx个点的连通分支中选出一个边集的方案数,那么点uuu对答案的贡献即为
f(u)=num[n−Sizeu]∏v∈son(u)num[Sizev]
f(u)=num[n-Size_u]\prod\limits_{v\in son(u)}num[Size_v]
f(u)=num[n−Sizeu]v∈son(u)∏num[Sizev]
2.一条边u→vu\rightarrow vu→v对答案的贡献为(假设uuu是vvv的父亲,将这条边的贡献记录在vvv点)
g(v)=num[n−Sizeu]∏s∈son(u)−{v}num[Sizes]∏t∈son(v)num[Sizet]
g(v)=num[n-Size_u]\prod\limits_{s\in son(u)-\{v\}}num[Size_s]\prod\limits_{t\in son(v)}num[Size_t]
g(v)=num[n−Sizeu]s∈son(u)−{v}∏num[Sizes]t∈son(v)∏num[Sizet]
也即
g(v)=f(u)num[Sizev]f(v)num[n−Sizev]
g(v)=\frac{f(u)}{num[Size_v]}\frac{f(v)}{num[n-Size_v]}
g(v)=num[Sizev]f(u)num[n−Sizev]f(v)
树形DPDPDP一遍,答案即为∑i=1n(f(i)−g(i))\sum\limits_{i=1}^n(f(i)-g(i))i=1∑n(f(i)−g(i)),线性预处理num[x]num[x]num[x]以及num−1[x]num^{-1}[x]num−1[x],时间复杂度O(n)O(n)O(n)
Code
#include<cstdio>
#include<vector>
using namespace std;
typedef long long ll;
#define maxn 200005
#define mod 998244353
#define inv2 499122177
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int b[maxn],c[maxn],ib[maxn],ic[maxn];
void init(int n=2e5)
{
b[0]=1;
for(int i=1;i<=n;i++)b[i]=mul(2,b[i-1]);
c[0]=1;
for(int i=1;i<=n;i++)c[i]=mul(b[i],c[i-1]);
ib[0]=1;
for(int i=1;i<=n;i++)ib[i]=mul(inv2,ib[i-1]);
ic[0]=1;
for(int i=1;i<=n;i++)ic[i]=mul(ib[i],ic[i-1]);
}
int n,Size[maxn],f[maxn],g[maxn];
vector<int>e[maxn];
void dfs(int u,int fa)
{
Size[u]=1;
f[u]=1;
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==fa)continue;
dfs(v,u);
f[u]=mul(f[u],c[Size[v]]);
Size[u]+=Size[v];
}
f[u]=mul(f[u],c[n-Size[u]]);
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==fa)continue;
g[v]=1;
g[v]=mul(f[v],ic[n-Size[v]]);//son of v
g[v]=mul(g[v],mul(f[u],ic[Size[v]]));//brother of v&&out of u
}
}
int main()
{
init();
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
e[u].push_back(v),e[v].push_back(u);
}
dfs(1,0);
int ans=0;
for(int i=1;i<=n;i++)ans=add(ans,add(f[i],mod-g[i]));
printf("%d\n",ans);
return 0;
}