引子
如果把修改扔开,这就是一道十分经典的树形DP入门题。
然而加上修改,这道题瞬间毒瘤了起来。
先把转移方程摆在下面:
f
[
i
]
[
0
]
=
∑
j
∈
s
o
n
m
a
x
(
f
[
j
]
[
0
]
,
f
[
j
]
[
1
]
)
f
[
i
]
[
1
]
=
∑
j
∈
s
o
n
f
[
j
]
[
0
]
+
w
[
i
]
f[i][0]=\sum_{j \in son}max(f[j][0],f[j][1]) \\ f[i][1]=\sum_{j \in son} f[j][0]+w[i]
f[i][0]=j∈son∑max(f[j][0],f[j][1])f[i][1]=j∈son∑f[j][0]+w[i]
我们首先考虑处理一种特殊的情况——树形成了一条链。这样一来,这个转移式就可以化简:
f
[
i
]
[
0
]
=
m
a
x
(
f
[
j
]
[
0
]
,
f
[
j
]
[
1
]
)
f
[
i
]
[
1
]
=
f
[
j
]
[
0
]
+
w
[
i
]
f[i][0]=max(f[j][0],f[j][1]) \\ f[i][1]=f[j][0]+w[i]
f[i][0]=max(f[j][0],f[j][1])f[i][1]=f[j][0]+w[i]
这样一来,我们尝试用一个矩阵运算来转移:
(
0
0
w
[
i
]
−
∞
)
o
p
(
f
[
j
]
[
0
]
f
[
j
]
[
1
]
)
=
(
f
[
i
]
[
0
]
f
[
i
]
[
1
]
)
\begin{pmatrix} 0&0 \\ w[i]& -\infin \end{pmatrix}op \begin{pmatrix}f[j][0] \\ f[j][1] \end{pmatrix}=\begin{pmatrix} f[i][0] \\ f[i][1] \end{pmatrix}
(0w[i]0−∞)op(f[j][0]f[j][1])=(f[i][0]f[i][1])
其中,
A
o
p
B
=
C
A\ op\ B=C
A op B=C定义为:
C
i
,
j
=
max
{
A
i
,
k
+
B
k
,
j
∣
k
∈
[
1
,
n
]
}
C_{i,j}=\max\{ A_{i,k}+B_{k,j}| k \in [1,n]\}
Ci,j=max{Ai,k+Bk,j∣k∈[1,n]}
那么,对于一个修改操作,就可以直接更新这个矩阵,用线段树维护区间
o
p
op
op,这样就解决了这个问题。
现在考虑一般的树。对于一个节点
i
i
i,我们暂时先考虑其中的一个儿子
j
j
j,那么就能得到:
(
f
′
[
i
]
[
0
]
f
′
[
i
]
[
0
]
f
′
[
i
]
[
1
]
−
∞
)
o
p
(
f
[
j
]
[
0
]
f
[
j
]
[
1
]
)
=
(
f
[
i
]
[
0
]
f
[
i
]
[
1
]
)
\begin{pmatrix} f'[i][0]&f'[i][0] \\ f'[i][1]&-\infin \end{pmatrix} op \begin{pmatrix}f[j][0] \\ f[j][1]\end{pmatrix}=\begin{pmatrix} f[i][0] \\ f[i][1] \end{pmatrix}
(f′[i][0]f′[i][1]f′[i][0]−∞)op(f[j][0]f[j][1])=(f[i][0]f[i][1])
f
′
f'
f′表示考虑除
j
j
j外的儿子的答案。
在这里,
j
j
j的地位如此特殊,便不难想到对这棵树做树链剖分,
j
j
j就是
i
i
i的重儿子。那么对于重链上的信息,模仿上述序列的方法维护;而轻儿子的信息则暴力上传即可。
接下来正式讲讲实现。
对于修改操作,我们直接修改目标位置的矩阵。此时,只有该位置的祖先的答案才有可能变化。于是,修改矩阵以后,更新它所在重链的信息,再跳转至新的重链上。此时,原来重链的顶端就变为了该点的轻儿子,故暴力上传信息即可。之后重复这一过程。
对于查询操作,由于每一处的最佳决策均转移至根节点所在重链上,所直接查询根节点所在重链即可。总的时间复杂度为
O
(
n
log
2
2
n
)
O(n\log_2^2n)
O(nlog22n)
#include<bits/stdc++.h>
#define lch (i << 1)
#define rch ((i << 1) | 1)
#define mid ((t[i].l + t[i].r) >> 1)
using namespace std;
const int mn = 100005, inf = -(1 << 30);
struct matrix{
int a[2][2];
matrix() {memset(a, 0, sizeof a);}
matrix(int x, int y, int z, int w) {a[0][0] = x, a[0][1] = y, a[1][0] = z, a[1][1] = w;}
matrix operator* (const matrix b) const
{
matrix ret;
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
for(int k = 0; k < 2; k++)
ret.a[i][j] = max(ret.a[i][j], a[i][k] + b.a[k][j]);
return ret;
}
}tmp[mn];
struct seg{
int l, r;
matrix val;
}t[mn << 3];
struct edge{
int to, nxt;
}e[mn << 1];
int fir[mn], cnt;
int v[mn], f[mn][2];
int fa[mn], siz[mn], son[mn], top[mn], num[mn], pos[mn], bot[mn], times;
inline void addedge(int a, int b) {e[++cnt] = (edge) {b, fir[a]}, fir[a] = cnt;}
inline int getint()
{
int ret = 0, flg = 1; char c;
while((c = getchar()) < '0' || c > '9')
if(c == '-') flg = -1;
while(c >= '0' && c <= '9')
ret = ret * 10 + c - '0', c = getchar();
return ret * flg;
}
void dfs1(int s, int f)
{
fa[s] = f, siz[s] = 1;
for(int i = fir[s]; i; i = e[i].nxt)
{
int t = e[i].to;
if(t != f)
{
dfs1(t, s), siz[s] += siz[t];
if(siz[son[s]] < siz[t]) son[s] = t;
}
}
}
void dfs2(int s)
{
num[s] = ++times, pos[times] = bot[s] = s;
if(son[s]) top[son[s]] = top[s], dfs2(son[s]), bot[s] = bot[son[s]];
for(int i = fir[s]; i; i = e[i].nxt)
{
int t = e[i].to;
if(t != fa[s] && t != son[s])
top[t] = t, dfs2(t);
}
}
void dfs(int s)
{
f[s][1] = v[s];
for(int i = fir[s]; i; i = e[i].nxt)
{
int t = e[i].to;
if(t != fa[s])
dfs(t), f[s][0] += max(f[t][0], f[t][1]), f[s][1] += f[t][0];
}
}
void make_tree(int i, int l, int r)
{
t[i] = (seg) {l, r, matrix(0, 0, 0, 0)};
if(l == r)
{
int s = pos[l], g0 = 0, g1 = v[s];
for(int i = fir[s]; i; i = e[i].nxt)
{
int t = e[i].to;
if(t != fa[s] && t != son[s])
g0 += max(f[t][0], f[t][1]), g1 += f[t][0];
}
tmp[l] = t[i].val = matrix(g0, g0, g1, inf);
return;
}
make_tree(lch, l, mid), make_tree(rch, mid + 1, r), t[i].val = t[lch].val * t[rch].val;
}
matrix getans(int i, int l, int r)
{
if(t[i].l == l && t[i].r == r) return t[i].val;
if(r <= mid) return getans(lch, l, r);
else if(l > mid) return getans(rch, l, r);
else return getans(lch, l, mid) * getans(rch, mid + 1, r);
}
void edit_tree(int i, int p)
{
if(t[i].l == p && t[i].r == p) {t[i].val = tmp[p]; return;}
if(p <= mid) edit_tree(lch, p);
else edit_tree(rch, p);
t[i].val = t[lch].val * t[rch].val;
}
void edit(int p, int w)
{
tmp[num[p]].a[1][0] += w - v[p], v[p] = w;
while(p)
{
matrix a = getans(1, num[top[p]], num[bot[p]]);
edit_tree(1, num[p]);
matrix b = getans(1, num[top[p]], num[bot[p]]);
p = fa[top[p]];
if(!p) return;
int x = num[p], g0 = a.a[0][0], g1 = a.a[1][0], f0 = b.a[0][0], f1 = b.a[1][0];
tmp[x].a[0][0] = tmp[x].a[0][1] = tmp[x].a[0][0] + max(f0, f1) - max(g0, g1),
tmp[x].a[1][0] += f0 - g0;
}
}
int main()
{
int n = getint(), m = getint(), a, b;
for(int i = 1; i <= n; i++)
v[i] = getint();
for(int i = 1; i < n; i++)
a = getint(), b = getint(), addedge(a, b), addedge(b, a);
dfs1(1, 0), top[1] = 1, dfs2(1), dfs(1), make_tree(1, 1, n);
while(m--)
{
a = getint(), b = getint(), edit(a, b);
matrix ans = getans(1, num[1], num[bot[1]]);
printf("%d\n", max(ans.a[0][0], ans.a[1][0]));
}
}