算法简介:
树上启发式合并 (DSU on Tree) 是一种能够在 Θ(nlog2n) Θ ( n log 2 n ) 的时间里快速统计子树信息的小技巧。大致的思想是:将小的子树的信息合并到大的子树的信息中。具体的实现步骤一般是先 DFS 不是最大的子树并将该子树信息清空,然后 DFS 最大的子树并保留子树信息,在根据保留下来的最大子树信息算出原子树的信息。
例题一:子树颜色统计
题目大意:给定一棵有根树,每个节点有个颜色 coli c o l i ,求每个点的子树中有几个点的颜色 =qryi = q r y i 。
首先,我们想到用 cnt c n t 数组维护每个点子树中每种颜色出现的次数。这样的时间复杂度显然是 Θ(n2) Θ ( n 2 ) 的(在树的形态是一条链达到上界)。我们是用刚才的技巧去优化这个算法。时间复杂度 Θ(nlog2n) Θ ( n log 2 n ) 。
#include <cstdio>
const int maxn = 1000005;
int n, col[maxn], qry[maxn], ans[maxn];
int m, ter[maxn], nxt[maxn], lnk[maxn];
int skip, cnt[maxn], sz[maxn], ch[maxn];
void addedge(int u, int v) {
ter[++m] = v;
nxt[m] = lnk[u];
lnk[u] = m;
}
void gsz(int u) {
sz[u] = 1;
for (int i = lnk[u]; i; i = nxt[i]) {
gsz(ter[i]);
sz[u] += sz[ter[i]];
if (sz[ter[i]] > sz[ch[u]]) {
ch[u] = ter[i];
}
}
}
void edt(int u, int v) {
cnt[col[u]] += v;
for (int i = lnk[u]; i; i = nxt[i]) {
if (ter[i] != skip) {
edt(ter[i], v);
}
}
}
void dfs(int u, bool flag = 0) {
int son = 0;
for (int i = lnk[u]; i; i = nxt[i]) {
if (ter[i] != ch[u]) {
dfs(ter[i]);
}
}
if (ch[u]) {
skip = ch[u];
dfs(ch[u], 1);
}
edt(u, 1);
ans[u] = cnt[qry[u]];
skip = 0;
if (flag == 0) {
edt(u, -1);
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &col[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &qry[i]);
}
for (int u, i = 2; i <= n; i++) {
scanf("%d", &u);
addedge(u, i);
}
gsz(1);
dfs(1);
for (int i = 1; i <= n; i++) {
printf("%d\n", ans[i]);
}
return 0;
}
例题二:CodeForces 600E Lomsat Gelral
题目大意:给定一棵有根树,每个点有一种颜色。我们称一种颜色占领了一个子树当且仅当没有其他颜色在这个子树中出现得比它多。求占领每个子树的所有颜色之和。
还是记录每个颜色的出现次数,顺带维护 maxcnt m a x c n t 和 cnt c n t 值等于 maxcnt m a x c n t 的颜色个数。用启发式合并的技巧优化即可。
#include <cstdio>
const int maxn = 500005;
const int maxm = 1000005;
int n, col[maxn], sz[maxn], ch[maxn];
int m, ter[maxm], nxt[maxm], lnk[maxn];
int skip, maxcnt, cnt[maxn];
long long sum, ans[maxn];
void addedge(int u, int v) {
ter[++m] = v;
nxt[m] = lnk[u];
lnk[u] = m;
}
void dfs(int u, int p) {
sz[u] = 1;
for (int v, i = lnk[u]; i; i = nxt[i]) {
v = ter[i];
if (v == p) continue;
dfs(v, u);
sz[u] += sz[v];
if (sz[v] > sz[ch[u]]) {
ch[u] = v;
}
}
}
void update(int u, int p, int x) {
cnt[col[u]] += x;
if (x > 0 && cnt[col[u]] > maxcnt) {
maxcnt = cnt[col[u]], sum = col[u];
} else if (x > 0 && cnt[col[u]] == maxcnt) {
sum += col[u];
}
for (int v, i = lnk[u]; i; i = nxt[i]) {
v = ter[i];
if (v == p || v == skip) continue;
update(v, u, x);
}
}
void solve(int u, int p, bool flag = 0) {
for (int v, i = lnk[u]; i; i = nxt[i]) {
v = ter[i];
if (v == p || v == ch[u]) continue;
solve(v, u);
}
if (ch[u]) {
solve(ch[u], u, 1);
skip = ch[u];
}
update(u, p, 1);
ans[u] = sum;
skip = 0;
if (flag == 0) {
update(u, p, -1);
maxcnt = sum = 0;
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &col[i]);
}
for (int u, v, i = 1; i < n; i++) {
scanf("%d %d", &u, &v);
addedge(u, v);
addedge(v, u);
}
dfs(1, 0);
solve(1, 0);
for (int i = 1; i <= n; i++) {
printf("%I64d%c", ans[i], " \n"[i == n]);
}
return 0;
}
例题三:CodeForces 741D Tree and Paths
我们称一个字符串 “秀的” 当且仅当重排它的字符可以组成一个回文串。给出一个有根树,每条边上有一个 a
到 v
之间的字符,求每个点的子树中所有简单路径可以组成的 “秀的” 字符串中的最长长度。
一个字符串是 “秀的” 等价于所有的字母出现次数有 0 0 或 个为奇数。我们给第 i i 个字母赋一个权值 ,那么一个字符串是 “秀的” 等价于所有的字母权值的异或和为 0 0 或 。问题转化成了:对于每个节点 u u ,在他的子树中找出 和 b b ,使得 并且 dep[a]+dep[b]−2∗dep[u] d e p [ a ] + d e p [ b ] − 2 ∗ d e p [ u ] 最大。
我们对于 u u 的每个子树,先更新答案,再更新最大值,这样就不会出现 和 b b 的 LCA 不是 的情况了。再用启发式合并优化即可。具体细节请见代码。
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int maxn = 1000005;
const int maxm = 1 << 22 | 2;
const int ninf = 0xc0c0c0c0;
int n, val[maxn];
int m, ter[maxn], nxt[maxn], lnk[maxn];
int sz[maxn], ch[maxn], wei[maxn], dep[maxn];
int skip, curdep, curans, mx[maxm], dp[maxn];
void addedge(int u, int v, int w) {
ter[++m] = v;
nxt[m] = lnk[u];
lnk[u] = m;
val[v] = w;
}
void dfs(int u, int p) {
sz[u] = 1;
for (int v, i = lnk[u]; i; i = nxt[i]) {
v = ter[i];
if (v == p) continue;
wei[v] = wei[u] ^ val[v];
dep[v] = dep[u] + 1;
dfs(v, u);
sz[u] += sz[v];
if (sz[v] > sz[ch[u]]) {
ch[u] = v;
}
}
}
void clear(int u) {
mx[wei[u]] = ninf;
}
void update(int u) {
curans = max(curans, mx[wei[u]] + dep[u] - 2 * curdep);
for (int i = 0; i < 22; i++) {
curans = max(curans, mx[wei[u] ^ (1 << i)] + dep[u] - 2 * curdep);
}
}
void insert(int u) {
mx[wei[u]] = max(mx[wei[u]], dep[u]);
}
template<void(*func)(int)> void edit(int u, int p) {
func(u);
for (int v, i = lnk[u]; i; i = nxt[i]) {
v = ter[i];
if (v == p) continue;
edit<func>(v, u);
}
}
void solve(int u, int p, bool flag = 0) {
for (int v, i = lnk[u]; i; i = nxt[i]) {
v = ter[i];
if (v == p || v == ch[u]) continue;
solve(v, u);
}
if (ch[u]) {
solve(ch[u], u, 1);
skip = ch[u];
}
curdep = dep[u];
for (int v, i = lnk[u]; i; i = nxt[i]) {
v = ter[i];
if (v == p) continue;
dp[u] = max(dp[u], dp[v]);
if (v == ch[u]) continue;
edit<update>(v, u);
edit<insert>(v, u);
}
update(u);
insert(u);
dp[u] = max(dp[u], curans);
skip = 0;
if (flag == 0) {
edit<clear>(u, p);
curans = ninf;
}
}
int main() {
memset(mx, 0xc0, sizeof(mx));
scanf("%d", &n);
char s[5];
for (int u, i = 2; i <= n; i++) {
scanf("%d %s", &u, s);
addedge(u, i, 1 << (s[0] - 'a'));
}
dfs(1, 0);
solve(1, 0);
for (int i = 1; i <= n; i++) {
printf("%d%c", dp[i], " \n"[i == n]);
}
return 0;
}