树套树并不是一个新的数据结构,它是由我们已经学过的数据结构,如线段树、平衡树、树状数组等相互结合而成,难度并不高,重点在于编码。由于是多个数据结构结合,所以对模板的运用需要达到融会贯通的程度。如果还不会一些基础的数据结构,建议先学习基础数据结构后再战。
首先看一道用STL处理的问题,来初步理解一下树套树的思想:(Acwing2488)
不难看出这是一题单点修改、区间查询的问题,用线段树维护这个序列可以达到O(nlogn)的复杂度,但是这个区间查询似乎有些不一样——查询某个数的前驱,这意味着如果仅仅是简单的线段树维护,查询的复杂度将最坏达到查询整个区间找前驱的的复杂度。可以考虑将原线段树中的每个子区间改成set,set查询一个数的前缀的复杂度为O(logn),这样以来总复杂度将控制在
。
那么线段树节点可以改成这样:
struct node {
int l, r;
multiset<int> s;
}tr[N * 4];
建树的时候我们可以在set中插入两个哨兵INF和-INF,这样可以保证一定可以查询到某个数的前驱
void build(int u, int l, int r) {
tr[u] = {l, r};
tr[u].s.insert(INF); tr[u].s.insert(-INF);
for (int i = l; i <= r; i++) tr[u].s.insert(num[i]);
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r);
}
那么修改某个数的时候,我们可以通过set的删除(erace)添加(insert)函数完成
void modify(int u, int p, int x) {
tr[u].s.erase(tr[u].s.find(num[p]));
tr[u].s.insert(x);
if (tr[u].l == tr[u].r) return;
int mid = tr[u].l + tr[u].r >> 1;
if (p <= mid) modify(u << 1, p, x);
else modify(u << 1 | 1, p, x);
}
查询某个数的时候可以通过set的内置函数low_bound二分查找某个数的前驱
int query(int u, int l, int r, int x) {
if (tr[u].l >= l && tr[u].r <= r) {
auto it = tr[u].s.lower_bound(x);
--it;
return *it;
}
int mid = tr[u].l + tr[u].r >> 1, res = -INF;
if (mid >= l) res = max(res, query(u << 1, l, r, x));
if (mid < r) res = max(res, query(u << 1 | 1, l, r, x));
return res;
}
这样即可完成题目所需要求。
最后附上主函数
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> num[i];
build(1, 1, n);
while (m--) {
int op;
cin >> op;
if (op == 1) {
int p, x;
cin >> p >> x;
modify(1, p, x);
num[p] = x;
}
else {
int l, r, x;
cin >> l >> r >> x;
cout << query(1, l, r, x) << endl;
}
}
return 0;
}
看完这题之后想必你对树套树的思想有一个大概的了解了,但是并不是每题都可以通过STL来节省代码量的,如果遇到大量需要手动维护的节点信息,还是手写数据结构比较好,例如(洛谷P3380二逼平衡树):
首先观察到总体是一个区间查询和单点修改的问题,考虑用线段树维护整个序列,再观察每一个查询操作,需要找前驱、后继和排名,可以考虑用平衡树来维护,那么我们可以将每一个线段树的节点替换成一棵splay树,便于维护每个数的排名和前后继。
通过splay一章的学习我们知道查询某个数的排名可以用get_k函数得到,但是如何得到区间内的排名是谁呢?观察到排名是一个连续的单调递增区间,如果我们任取某个数查询排名,可以知晓我们查询的这个数相较于要查的结果是满足大小关系的,整体满足二段性,可以用二分来查找。
其他操作不难,与基础知识一致。
小tips:树套树的函数众多,要同时写出线段树和splay两个数据结构大户的所有函数,这里建议先将所有函数声明,再写函数就不用考虑先后顺序了。
附代码:
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 2000010, INF = 1e9;
struct node {
int s[2], p, v;
int size;
void init(int _v, int _p) {
v = _v; p = _p;
size = 1;
}
}tr[N];
int L[N], R[N], T[N], w[N];
int n, m, idx;
void pushup(int u);
void rotate(int x);
void splay(int &root, int x, int k);
void insert(int& root, int v);
int get_k(int root, int x);
void update(int &root, int x, int y);
void build(int u, int l, int r);
int query(int u, int l, int r, int x);
void modify(int u, int p, int x);
int get_suc(int root, int x);
int query_suc(int u, int l, int r, int x);
int get_pre(int root, int x);
int query_pre(int u, int l, int r, int x);
void pushup(int u) {
tr[u].size = tr[tr[u].s[0]].size + tr[tr[u].s[1]].size + 1;
}
void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x; tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1]; tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y; tr[y].p = x;
pushup(y); pushup(x);
}
void splay(int &root, int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k) {
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if (!k) root = x;
}
void insert(int& root, int v) {
int u = root, p = 0;
while (u) p = u, u = tr[u].s[tr[u].v < v];
u = ++idx;
if (p) tr[p].s[tr[p].v < v] = u;
tr[u].init(v, p);
splay(root, u, 0);
}
int get_k(int root, int x) { //查找平衡树root的x的排名
int u = root, res = 0;
while (u) {
if (tr[u].v < x) res += tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
void update(int &root, int x, int y) {
int u = root;
while (u) {
if (tr[u].v == x) break;
u = tr[u].s[tr[u].v < x];
}
splay(root, u, 0);
int l = tr[u].s[0], r = tr[u].s[1];
while (tr[l].s[1]) l = tr[l].s[1];
while (tr[r].s[0]) r = tr[r].s[0];
splay(root, l, 0); splay(root, r, l);
tr[r].s[0] = 0;
pushup(r); pushup(l);
insert(root, y);
}
void build(int u, int l, int r) {
L[u] = l; R[u] = r;
insert(T[u], -INF); insert(T[u], INF);
for (int i = l; i <= r; i++) insert(T[u], w[i]);
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r);
}
int query(int u, int l, int r, int x) { //查询区间[l, r]里x的排名
if (L[u] >= l && R[u] <= r) return get_k(T[u], x) - 1;
int res = 0, mid = L[u] + R[u] >> 1;
if (l <= mid) res += query(u << 1, l, r, x);
if (r > mid) res += query(u << 1 | 1, l, r, x);
return res;
}
void modify(int u, int p, int x) {
update(T[u], w[p], x);
if (L[u] == R[u]) return;
int mid = L[u] + R[u] >> 1;
if (p <= mid) modify(u << 1, p, x);
else modify(u << 1 | 1, p, x);
}
int get_suc(int root, int x) {
int u = root, res = INF;
while(u) {
if (tr[u].v > x) res = min(res, tr[u].v), u = tr[u].s[0];
else u = tr[u].s[1];
}
return res;
}
int get_pre(int root, int x) {
int u = root, res = -INF;
while(u) {
if (tr[u].v < x) res = max(res, tr[u].v), u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
int query_pre(int u, int l, int r, int x) {
if (L[u] >= l && R[u] <= r) return get_pre(T[u], x);
int mid = L[u] + R[u] >> 1, res = -INF;
if (l <= mid) res = max(res, query_pre(u << 1, l, r, x));
if (r > mid) res = max(res, query_pre(u << 1 | 1, l, r, x));
return res;
}
int query_suc(int u, int l, int r, int x) {
if (L[u] >= l && R[u] <= r) return get_suc(T[u], x);
int mid = L[u] + R[u] >> 1, res = INF;
if (l <= mid) res = min(res, query_suc(u << 1, l, r, x));
if (r > mid) res = min(res, query_suc(u << 1 | 1, l, r, x));
return res;
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
build(1, 1, n);
while (m--) {
int op, a, b, x;
scanf("%d", &op);
if (op == 1) {
scanf("%d%d%d", &a, &b, &x);
printf("%d\n", query(1, a, b, x) + 1);
}
else if (op == 2)
{
scanf("%d%d%d", &a, &b, &x);
int l = 0, r = 1e8;
while (l < r)
{
int mid = l + r + 1 >> 1;
if (query(1, a, b, mid) + 1 <= x) l = mid;
else r = mid - 1;
}
printf("%d\n", r);
}
else if (op == 3)
{
scanf("%d%d", &a, &x);
modify(1, a, x);
w[a] = x;
}
else if (op == 4)
{
scanf("%d%d%d", &a, &b, &x);
printf("%d\n", query_pre(1, a, b, x));
}
else
{
scanf("%d%d%d", &a, &b, &x);
printf("%d\n", query_suc(1, a, b, x));
}
}
return 0;
}
最后再来看一条线段树套线段树的题目:
这题虽短,但涉及的知识点还是很多的,离散化、权值线段树、动态开点、标记持久化等,我们依次分析。
首先我们分析一下操作类型,第一个操作是区间修改,如果把每个位置看作集合,那么就是在一段连续的集合里同时加上一个数,第二个操作是区间查询,那么考虑线段树维护查询操作,内层的树如果用平衡树的话,那么对于集合的维护就很复杂,所以我们可以把外层的树设置为权值线段树。(因为权值的范围为1<<31,而点点个数最多只有50000,所以需要离散化一下。)
权值线段树和线段树维护的东西刚好相反,线段树是建立在下标上的,维护的是每个下标有哪些权值,而权值线段树是建立在权值(数轴)上的,维护的是每个权值在哪些下标里。那么对于第一个操作,线段树上是在一段连续的下标里加上一个权值c,在权值线段树里就是在权值为c的地方加上一段连续的下标。此时的区间修改就变成了单点修改,对于权值c,我们最多修改logn层权值线段树即可。
由于权值线段树里的每个节点都是一棵普通线段树,用来维护下标的修改,所以每执行一次第一个操作,我们就在内层的线段树的a-b所在的位置加上1,这个操作可以用懒标记来维护。
而对于内层的线段树来说,需要维护的操作只有add一个,所以我们可以用一个标记持久化的思想,把add作为线段树的自身属性,这样就不用每次都pushdown操作了,可以节省不少时间。
再其次由于线段树的空间占用非常之大,但是利用率并不高,并且有的节点用过一次之后就不会再被使用了,所以我们可以动态开点,用一个数组存储下所有的节点,相当于节点分配站,当我需要的时候再给他分配。
最后附上完整代码
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long LL;
const int N = 50010, P = N * 17 * 17, M = N * 4;
int n, m;
struct Tree
{
int l, r;
LL sum, add;
}tr[P];
int L[M], R[M], T[M], idx;
struct Query
{
int op, a, b, c;
}q[N];
vector<int> nums;
int get(int x)
{
return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}
void build(int u, int l, int r)
{
L[u] = l, R[u] = r, T[u] = ++ idx;
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
int intersection(int a, int b, int c, int d)
{
return min(b, d) - max(a, c) + 1;
}
void update(int u, int l, int r, int pl, int pr)
{
tr[u].sum += intersection(l, r, pl, pr);
if (l >= pl && r <= pr)
{
tr[u].add ++ ;
return;
}
int mid = l + r >> 1;
if (pl <= mid)
{
if (!tr[u].l) tr[u].l = ++ idx;
update(tr[u].l, l, mid, pl, pr);
}
if (pr > mid)
{
if (!tr[u].r) tr[u].r = ++ idx;
update(tr[u].r, mid + 1, r, pl, pr);
}
}
void change(int u, int a, int b, int c)
{
update(T[u], 1, n, a, b);
if (L[u] == R[u]) return;
int mid = L[u] + R[u] >> 1;
if (c <= mid) change(u << 1, a, b, c);
else change(u << 1 | 1, a, b, c);
}
LL get_sum(int u, int l, int r, int pl, int pr, int add)
{
if (l >= pl && r <= pr) return tr[u].sum + (r - l + 1LL) * add;
int mid = l + r >> 1;
LL res = 0;
add += tr[u].add;
if (pl <= mid)
{
if (tr[u].l) res += get_sum(tr[u].l, l, mid, pl, pr, add);
else res += intersection(l, mid, pl, pr) * add;
}
if (pr > mid)
{
if (tr[u].r) res += get_sum(tr[u].r, mid + 1, r, pl, pr, add);
else res += intersection(mid + 1, r, pl, pr) * add;
}
return res;
}
int query(int u, int a, int b, int c)
{
if (L[u] == R[u]) return R[u];
int mid = L[u] + R[u] >> 1;
LL k = get_sum(T[u << 1 | 1], 1, n, a, b, 0);
if (k >= c) return query(u << 1 | 1, a, b, c);
return query(u << 1, a, b, c - k);
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 0; i < m; i ++ )
{
scanf("%d%d%d%d", &q[i].op, &q[i].a, &q[i].b, &q[i].c);
if (q[i].op == 1) nums.push_back(q[i].c);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
build(1, 0, nums.size() - 1);
for (int i = 0; i < m; i ++ )
{
int op = q[i].op, a = q[i].a, b = q[i].b, c = q[i].c;
if (op == 1) change(1, a, b, get(c));
else printf("%d\n", nums[query(1, a, b, c)]);
}
return 0;
}