Hello my friend
题目背景:
分析:树型DP
听说这个是套路,我现在才知道是不是可以直接退个役什么的······显然对于黑点而言,我们需要的是期望次数,而对于白点而言,我们需要的是期望概率,先考虑黑点的贡献,对于以1为根的有根树,令f[i]表示从点i到结束的期望经过的黑点个数,deg[i]表示i的度数。那么:
显然最后可以获得k[1]和b[1],显然b[1]就是f[1]了,所以f[1]在过程中可以不用显示维护,只要维护k[i], b[i]即可。
继续考虑白点的贡献,显然就是到这个白点的概率,定义dp[i]表示到点i的概率,显然不可能不经过父亲节点,那么,令g[i]表示,从fa[i]到i的概率,令f[i]表示从i到fa[i]的概率:

显然,f可以比较简单的求出来,有了f之后g也就非常好处理了,所以原题只需要3遍dfs即可,复杂度O(n)。
Source:
/*
created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>
#include <ctime>
#include <bitset>
inline char read() {
static const int IN_LEN = 1024 * 1024;
static char buf[IN_LEN], *s, *t;
if (s == t) {
t = (s = buf) + fread(buf, 1, IN_LEN, stdin);
if (s == t) return -1;
}
return *s++;
}
/*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = read(), iosig = false; !isdigit(c); c = read()) {
if (c == -1) return ;
if (c == '-') iosig = true;
}
for (x = 0; isdigit(c); c = read())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {
if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
*oh++ = c;
}
template<class T>
inline void W(T x) {
static int buf[30], cnt;
if (x == 0) write_char('0');
else {
if (x < 0) write_char('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) write_char(buf[cnt--]);
}
}
inline void flush() {
fwrite(obuf, 1, oh - obuf, stdout);
}
///*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
if (c == '-') iosig = true;
for (x = 0; isdigit(c); c = getchar())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int MAXN = 100000 + 10;
const int mod = 998244353;
std::vector<int> edge[MAXN];
int n, x, y, ans;
int d[MAXN], f[MAXN], k[MAXN], b[MAXN], g[MAXN], dp[MAXN], c[MAXN], sum[MAXN];
char s[MAXN];
inline void add(int &x, int t) {
x += t, (x >= mod) ? (x -= mod) : (0);
}
inline int mod_pow(int a, int b) {
int ans = 1;
for (; b; b >>= 1, a = (long long)a * a % mod)
if (b & 1) ans = (long long)ans * a % mod;
return ans;
}
inline void add_edge(int x, int y) {
edge[x].push_back(y), edge[y].push_back(x), d[x]++, d[y]++;
}
inline void read_in() {
scanf("%d%s", &n, s + 1);
for (int i = 1; i <= n; ++i) c[i] = s[i] - '0';
for (int i = 1; i < n; ++i) R(x), R(y), add_edge(x, y);
}
inline void dfs1(int cur, int fa) {
if (d[cur] == 1) {
f[cur] = k[cur] = 0, b[cur] = c[cur];
return ;
}
int cur_k = 1, cur_b = c[cur], cur_f = 0, inv = mod_pow(d[cur], mod - 2);
for (int p = 0; p < edge[cur].size(); ++p) {
int v = edge[cur][p];
if (v != fa) {
dfs1(v, cur), add(cur_k, mod - (long long)inv * k[v] % mod);
add(cur_b, (long long)inv * b[v] % mod), add(cur_f, f[v]);
}
}
f[cur] = mod_pow((d[cur] - cur_f + mod) % mod, mod - 2);
int x = mod_pow(cur_k, mod - 2);
k[cur] = (long long)inv * x % mod, b[cur] = (long long)cur_b * x % mod;
}
inline void dfs2(int cur, int fa) {
if (fa != 0) {
int temp = (((long long)d[fa] - sum[fa] +
f[cur] - g[fa]) % mod + mod) % mod;
g[cur] = mod_pow(temp, mod - 2);
}
for (int p = 0; p < edge[cur].size(); ++p) {
int v = edge[cur][p];
if (v != fa) add(sum[cur], f[v]);
}
for (int p = 0; p < edge[cur].size(); ++p) {
int v = edge[cur][p];
if (v != fa) dfs2(v, cur);
}
}
inline void dfs3(int cur, int fa) {
if (c[cur] == 0) add(ans, dp[cur]);
for (int p = 0; p < edge[cur].size(); ++p) {
int v = edge[cur][p];
if (v != fa) dp[v] = (long long)dp[cur] * g[v] % mod, dfs3(v, cur);
}
}
int main() {
freopen("sad.in", "r", stdin);
freopen("sad.out", "w", stdout);
read_in();
dfs1(1, 0), dfs2(1, 0), dp[1] = 1, dfs3(1, 0);
std::cout << (add(ans, b[1]), ans) << '\n';
return 0;
}
/*
4
1011
1 2
1 3
3 4
*/