题目链接。题目的大意就是在树上修改和查询相关的信息。看到这个其实我们很容易的想到线段树,如果是在数组上也就是线性结构中实现这些操作,其实很简单,就是线段树的基本功能。但是由于数据的的形式的组织的改变,使得线段树不能再直接的使用。那么如何在树上完成这些操作呢?
其实思考问题都是有这样的一个过程:将未知转化为已知。我么知道支持修改和查询,线段是是一个很好的工具。所以我们会有一个很原始的想法,能不能继续使用这种数据结构来维护数据,那么我们就面临一个问题,线段树只能维护线性的数据结构,所以我们就要把树形的数据转化为线性的。
树链剖分:树链剖分就是来做这件事情的。树链剖分将一颗树通过一定的规则将整棵树分解成一条一条的链子,然后把这些链子链接起来,形成一条链子。现在最常用的规则就是轻重链规则,通过节点的子树的数量区分轻节点和重节点,父节点和重节点之间的边叫做重边,依次类推其他的概念。通过两遍dfs找到这些信息。其实这个思想在我看来就是把树形的数据结构通过一定的规则映射到线性的数据结构,然后再把映射好的线性的数据结构放到线段树上,使用线段树来维护数据。嗯,其实我的介绍都是泛泛而谈,没有严格的证明,只是提供一个理解的方法。下边根据BZOJ1036给出树剖的基本代码实现。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include<algorithm>
#pragma warning(disable:4996)
using namespace std;
#define MAXN 30010
int head[MAXN];
struct node
{
int y, next;
}edge[2 * MAXN];
int id[MAXN], son[MAXN], dep[MAXN], fa[MAXN], siz[MAXN], a[MAXN];
int l, x, y, n, q, tot;
char s[10];
int pre[MAXN], top[MAXN];
void add(int x, int y)
{
l++;
edge[l].y = y;
edge[l].next = head[x];
head[x] = l;
}
void dfs1(int x, int f)
{
int y;
fa[x] = f;
son[x] = 0;
siz[x] = 1;
for (int i = head[x]; i != -1; i = edge[i].next)
if (edge[i].y != f)
{
y = edge[i].y;
dep[y] = dep[x] + 1;
dfs1(y, x);
siz[x] += siz[y];
if (siz[son[x]]<siz[y])
son[x] = y;
}
}
void dfs2(int x, int tp)
{
int y;
top[x] = tp;
id[x] = ++tot;// tot is the total number
pre[id[x]] = x;
if (son[x]) dfs2(son[x], tp);
for (int i = head[x]; i != -1; i = edge[i].next)
if (edge[i].y != fa[x] && edge[i].y != son[x])
{
y = edge[i].y;
dfs2(y, y);
}
}
struct point
{
int l, r, sum, max;
}tr[4 * MAXN];
void updata(int p)
{
tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum;
tr[p].max = max(tr[p << 1].max, tr[p << 1 | 1].max);
}
void build(int p, int l, int r)
{
tr[p].l = l; tr[p].r = r;
if (l == r) { tr[p].sum = tr[p].max = a[pre[l]]; return; }
int mid = (l + r) >> 1;
build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r);
updata(p);
}
void change(int p, int x, int y)
{
if (tr[p].l == x && tr[p].r == x)
{
tr[p].sum = tr[p].max = y;
return;
}
int mid = (tr[p].l + tr[p].r) >> 1;
if (x <= mid) change(p << 1, x, y);
if (x>mid) change(p << 1 | 1, x, y);
updata(p);
}
int ask_max(int p, int l, int r)
{
if (tr[p].l == l && tr[p].r == r)
return tr[p].max;
int mid = (tr[p].l + tr[p].r) >> 1;
if (r <= mid) return ask_max(p << 1, l, r);
if (l>mid) return ask_max(p << 1 | 1, l, r);
if (l <= mid && r>mid)
{
int s1 = ask_max(p << 1, l, mid);
int s2 = ask_max(p << 1 | 1, mid + 1, r);
return max(s1, s2);
}
}
int ask_sum(int p, int l, int r)
{
if (tr[p].l == l && tr[p].r == r)
return tr[p].sum;
int mid = (tr[p].l + tr[p].r) >> 1;
if (r <= mid) return ask_sum(p << 1, l, r);
if (l>mid) return ask_sum(p << 1 | 1, l, r);
if (l <= mid && r>mid)
{
int s1 = ask_sum(p << 1, l, mid);
int s2 = ask_sum(p << 1 | 1, mid + 1, r);
return s1 + s2;
}
}
int find_max(int x, int y)
{
int f1 = top[x], f2 = top[y], tmp = -0x3f3f3f;//有负数
while (f1 != f2)
{
if (dep[f1]<dep[f2])
{
swap(f1, f2); swap(x, y);
}
tmp = max(tmp, ask_max(1, id[f1], id[x]));
x = fa[f1]; f1 = top[x];
}
if (x == y) return max(tmp, ask_max(1, id[x], id[x]));
if (dep[x]>dep[y]) swap(x, y);
return max(tmp, ask_max(1, id[x], id[y]));
}
int find_sum(int x, int y)
{
int f1 = top[x], f2 = top[y], tmp = 0;
while (f1 != f2)
{
if (dep[f1]<dep[f2])
{
swap(f1, f2); swap(x, y);
}
tmp += ask_sum(1, id[f1], id[x]);
x = fa[f1]; f1 = top[x];
}
if (x == y) return tmp + ask_sum(1, id[x], id[y]);
if (dep[x]>dep[y]) swap(x, y);
return tmp + ask_sum(1, id[x], id[y]);
}
int main()
{
scanf("%d", &n);
memset(head, -1, sizeof(head));
for (int i = 1; i<n; i++)
{
scanf("%d%d", &x, &y);
add(x, y);
add(y, x);
}
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
dfs1(1, 0);
dfs2(1, 1);
build(1, 1, n);
scanf("%d", &q);
while (q--)
{
scanf("%s%d%d", s, &x, &y);
if (s[0] == 'C') change(1, id[x], y);
if (s[0] == 'Q' && s[1] == 'M') printf("%d\n", find_max(x, y));
if (s[0] == 'Q' && s[1] == 'S') printf("%d\n", find_sum(x, y));//if there is X types of queries, we need to write X parts for the queries
}
}