题意
一棵树,每个点对它到根的路上距离小于等于did_idi的点有aia_iai点影响,每个点权值等于收到的影响之和。
每条边有一个出现的概率。每次询问,求某个点所在联通块的权值之和的平方的期望。
思路
首先权值可以倍增加差分搞出来。
然后权值之和的平方的期望,不能等同于权值和的期望的平方,后者的每个和的平方都多乘了一遍概率。
30分做法:对于每次询问,以询问点为根O(n)O(n)O(n)树形DP。
100分做法:2遍树形DP,第一遍处理子树内的情况,第二遍处理子树外的。子树外的期望等于他父亲子树的期望去掉本身的期望再加上父亲子树外的期望,所以可以从上向下DP。
看两种从父亲子树中去掉某个儿子的期望的做法
第一种:
void Delete1(int u, int v, ll w) // 为了更好的观看体验,去掉了取模
{
f2[u] = (f2[u]-w*f2[v]);
f1[u] = (f1[u]-2ll*w*f2[u]*f2[v]-w*f1[v]);
}
第二种:
void Delete2(int u, int v, ll w)
{
f1[u] = (f1[u]-2ll*w*f2[u]*f2[v]+w*f1[v]);
f2[u] = (f2[u]-w*f2[v]);
}
第一种当然是正解,合并的反向操作。
但是第二种把求期望当成简单的算术运算,(x−y)2=x2−2xy+y2(x-y)^2=x^2-2xy+y^2(x−y)2=x2−2xy+y2。实际上Delete2求期望的式子是(1−p)x2+p(x−y)2(1-p)x^2+p(x-y)^2(1−p)x2+p(x−y)2,翻译成自然语言即有(1−p)(1-p)(1−p)的概率不减去儿子子树,p的概率减去儿子子树,然而xxx中包含了没有连y的情况,所以减法是完全错误的。
两种打法的代码(其实差不多):
#include<cstdio>
#include<cstring>
#include<iostream>
#include<vector>
#define ll long long
using namespace std;
const ll mod = 998244353;
const int N = 200010;
const int E = 20;
int n, q, a[N], d[N];
vector<int> to[N], w[N];
int f[N][E];
ll val[N], f1[N], f2[N], ans[N];
void Dfs(int u, int fa) // 倍增和差分
{
f[u][0] = fa;
for (int i = 1; i < E; i++)
f[u][i] = f[f[u][i-1]][i-1];
int tmp = d[u]+1, uu = u;
(val[uu] += a[u]) %= mod;
for (int i = E-1; i >= 0; i--)
if (tmp >= (1<<i))
tmp -= (1<<i), uu = f[uu][i];
(val[uu] += mod-a[u]) %= mod;
for (int i = 0, sz = to[u].size(); i < sz; i++){
int v = to[u][i];
if (v == fa) continue;
Dfs(v, u);
(val[u] += val[v]) %= mod;
}
}
void Merge(int u, int v, ll ww)
{
(f1[u] += 2ll*ww*f2[u]%mod*f2[v]+ww*f1[v]%mod) %= mod;
(f2[u] += ww*f2[v]) %= mod;
}
void Delete(int u, int v, ll ww)
{
(f2[u] -= ww*f2[v]) %= mod;
if (f2[u] < 0) f2[u] += mod;
f1[u] = (f1[u]-2ll*ww*f2[u]%mod*f2[v]-ww*f1[v]%mod)%mod;
if (f1[u] < 0) f1[u] += mod;
}
void Dfs1(int u, int fa) // 子树内
{
f1[u] = val[u]*val[u]%mod;
f2[u] = val[u]%mod;
for (int i = 0, sz = to[u].size(); i < sz; i++){
int v = to[u][i];
ll ww = w[u][i];
if (v == fa) continue;
Dfs1(v, u);
Merge(u, v, ww);
}
}
void Dfs2(int u, int fa) // 子树外
{
ans[u] = f1[u];
for (int i = 0, sz = to[u].size(); i < sz; i++){
int v = to[u][i];
ll ww = w[u][i];
if (v == fa) continue;
Delete(u, v, ww);
Merge(v, u, ww);
Dfs2(v, u);
Delete(v, u, ww); // 这里v也需要还原
Merge(u, v, ww);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d%d", &a[i], &d[i]);
for (int i = 1; i < n; i++){
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
to[x].push_back(y);
w[x].push_back(z);
to[y].push_back(x);
w[y].push_back(z);
}
memset(val, 0, sizeof(val));
Dfs(1, 0);
Dfs1(1, 0);
Dfs2(1, 0);
scanf("%d", &q);
for (int i = 1; i <= q; i++){
int x;
scanf("%d", &x);
printf("%lld\n", ans[x]);
}
return 0;
}
#include<cstdio>
#include<cstring>
#include<iostream>
#include<vector>
#define ll long long
using namespace std;
const ll mod = 998244353;
const int N = 200010;
const int E = 20;
int n, q, a[N], d[N];
vector<int> to[N], w[N];
int f[N][E];
ll val[N], f1[N], f2[N], g1[N], g2[N], ans[N];
void Dfs(int u, int fa)
{
f[u][0] = fa;
for (int i = 1; i < E; i++)
f[u][i] = f[f[u][i-1]][i-1];
int tmp = d[u]+1, uu = u;
(val[uu] += a[u]) %= mod;
for (int i = E-1; i >= 0; i--)
if (tmp >= (1<<i))
tmp -= (1<<i), uu = f[uu][i];
(val[uu] += mod-a[u]) %= mod;
for (int i = 0, sz = to[u].size(); i < sz; i++){
int v = to[u][i];
if (v == fa) continue;
Dfs(v, u);
(val[u] += val[v]) %= mod;
}
}
void Dfs1(int u, int fa)
{
f1[u] = val[u]*val[u]%mod;
f2[u] = val[u]%mod;
for (int i = 0, sz = to[u].size(); i < sz; i++){
int v = to[u][i];
ll ww = w[u][i];
if (v == fa) continue;
Dfs1(v, u);
(f1[u] += 2ll*ww*f2[u]%mod*f2[v]+ww*f1[v]%mod) %= mod;
(f2[u] += ww*f2[v]) %= mod;
}
}
void Dfs2(int u, int fa)
{
for (int i = 0, sz = to[u].size(); i < sz; i++){
int v = to[u][i];
ll ww = w[u][i];
if (v == fa) continue;
ll tmp2 = ((f2[u]-ww*f2[v])%mod+mod)%mod;
ll tmp1 = (f1[u]-2ll*ww*tmp2%mod*f2[v]%mod-ww*f1[v]%mod+mod)%mod;
g1[v] = ww*(tmp1+2ll*tmp2%mod*g2[u]%mod+g1[u]%mod)%mod;
g2[v] = ww*(tmp2+g2[u])%mod;
Dfs2(v, u);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d%d", &a[i], &d[i]);
for (int i = 1; i < n; i++){
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
to[x].push_back(y);
w[x].push_back(z);
to[y].push_back(x);
w[y].push_back(z);
}
memset(val, 0, sizeof(val));
Dfs(1, 0);
Dfs1(1, 0);
memset(g1, 0, sizeof(g1));
memset(g2, 0, sizeof(g2));
Dfs2(1, 0);
scanf("%d", &q);
for (int i = 1; i <= q; i++){
int x;
scanf("%d", &x);
printf("%lld\n", (f1[x]+g1[x]+2ll*f2[x]*g2[x])%mod);
}
return 0;
}
套路。。。