一篇文章学会树套树(禁止套娃)

树套树并不是一个新的数据结构,它是由我们已经学过的数据结构,如线段树、平衡树、树状数组等相互结合而成,难度并不高,重点在于编码。由于是多个数据结构结合,所以对模板的运用需要达到融会贯通的程度。如果还不会一些基础的数据结构,建议先学习基础数据结构后再战。

首先看一道用STL处理的问题,来初步理解一下树套树的思想:(Acwing2488)

不难看出这是一题单点修改、区间查询的问题,用线段树维护这个序列可以达到O(nlogn)的复杂度,但是这个区间查询似乎有些不一样——查询某个数的前驱,这意味着如果仅仅是简单的线段树维护,查询的复杂度将最坏达到查询整个区间找前驱的O(n^2logn)的复杂度。可以考虑将原线段树中的每个子区间改成set,set查询一个数的前缀的复杂度为O(logn),这样以来总复杂度将控制在O(nlog^2n)

那么线段树节点可以改成这样:

struct node {
    int l, r;
    multiset<int> s;
}tr[N * 4];

建树的时候我们可以在set中插入两个哨兵INF和-INF,这样可以保证一定可以查询到某个数的前驱


void build(int u, int l, int r) {
    tr[u] = {l, r};
    tr[u].s.insert(INF); tr[u].s.insert(-INF);
    for (int i = l; i <= r; i++) tr[u].s.insert(num[i]);
    if (l == r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r);
}

那么修改某个数的时候,我们可以通过set的删除(erace)添加(insert)函数完成

void modify(int u, int p, int x) {
    tr[u].s.erase(tr[u].s.find(num[p]));
    tr[u].s.insert(x);
    if (tr[u].l == tr[u].r) return;
    int mid = tr[u].l + tr[u].r >> 1;
    if (p <= mid) modify(u << 1, p, x);
    else modify(u << 1 | 1, p, x);
}

查询某个数的时候可以通过set的内置函数low_bound二分查找某个数的前驱

int query(int u, int l, int r, int x) {
    if (tr[u].l >= l && tr[u].r <= r) {
        auto it = tr[u].s.lower_bound(x);
        --it;
        return *it;
    }
    int mid = tr[u].l + tr[u].r >> 1, res = -INF;
    if (mid >= l) res = max(res, query(u << 1, l, r, x));
    if (mid < r) res = max(res, query(u << 1 | 1, l, r, x));
    return res;
}

这样即可完成题目所需要求。

最后附上主函数

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> num[i];
    build(1, 1, n);
    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            int p, x;
            cin >> p >> x;
            modify(1, p, x);
            num[p] = x;
        }
        else {
            int l, r, x;
            cin >> l >> r >> x;
            cout << query(1, l, r, x) << endl;
        }
    }
    return 0;
}

看完这题之后想必你对树套树的思想有一个大概的了解了,但是并不是每题都可以通过STL来节省代码量的,如果遇到大量需要手动维护的节点信息,还是手写数据结构比较好,例如(洛谷P3380二逼平衡树):

 首先观察到总体是一个区间查询和单点修改的问题,考虑用线段树维护整个序列,再观察每一个查询操作,需要找前驱、后继和排名,可以考虑用平衡树来维护,那么我们可以将每一个线段树的节点替换成一棵splay树,便于维护每个数的排名和前后继。

通过splay一章的学习我们知道查询某个数的排名可以用get_k函数得到,但是如何得到区间内的排名是谁呢?观察到排名是一个连续的单调递增区间,如果我们任取某个数查询排名,可以知晓我们查询的这个数相较于要查的结果是满足大小关系的,整体满足二段性,可以用二分来查找。

其他操作不难,与基础知识一致。

小tips:树套树的函数众多,要同时写出线段树和splay两个数据结构大户的所有函数,这里建议先将所有函数声明,再写函数就不用考虑先后顺序了。

附代码:

#include <iostream>
#include <algorithm>
#include <cstring>

using namespace std;

const int N = 2000010, INF = 1e9;

struct node {
    int s[2], p, v;
    int size;
    void init(int _v, int _p) {
        v = _v; p = _p;
        size = 1;
    }
}tr[N];
int L[N], R[N], T[N], w[N];
int n, m, idx;

void pushup(int u);
void rotate(int x);
void splay(int &root, int x, int k);
void insert(int& root, int v);
int get_k(int root, int x);
void update(int &root, int x, int y);
void build(int u, int l, int r);
int query(int u, int l, int r, int x);
void modify(int u, int p, int x);
int get_suc(int root, int x);
int query_suc(int u, int l, int r, int x);
int get_pre(int root, int x);
int query_pre(int u, int l, int r, int x);

void pushup(int u) {
    tr[u].size = tr[tr[u].s[0]].size + tr[tr[u].s[1]].size + 1;
}

void rotate(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    tr[z].s[tr[z].s[1] == y] = x; tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1]; tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y; tr[y].p = x;
    pushup(y); pushup(x);
}

void splay(int &root, int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k) {
            if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
    if (!k) root = x;
}

void insert(int& root, int v) {
    int u = root, p = 0;
    while (u) p = u, u = tr[u].s[tr[u].v < v];
    u = ++idx;
    if (p) tr[p].s[tr[p].v < v] = u;
    tr[u].init(v, p);
    splay(root, u, 0);
}

int get_k(int root, int x) { //查找平衡树root的x的排名
    int u = root, res = 0;
    while (u) {
        if (tr[u].v < x) res += tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    return res;
}

void update(int &root, int x, int y) {
    int u = root;
    while (u) {
        if (tr[u].v == x) break;
        u = tr[u].s[tr[u].v < x];
    }
    splay(root, u, 0);
    int l = tr[u].s[0], r = tr[u].s[1];
    while (tr[l].s[1]) l = tr[l].s[1];
    while (tr[r].s[0]) r = tr[r].s[0];
    splay(root, l, 0); splay(root, r, l);
    tr[r].s[0] = 0;
    pushup(r); pushup(l);
    insert(root, y);
}

void build(int u, int l, int r) {
    L[u] = l; R[u] = r;
    insert(T[u], -INF); insert(T[u], INF);
    for (int i = l; i <= r; i++) insert(T[u], w[i]);
    if (l == r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r);
}

int query(int u, int l, int r, int x) { //查询区间[l, r]里x的排名
    if (L[u] >= l && R[u] <= r) return get_k(T[u], x) - 1;
    int res = 0, mid = L[u] + R[u] >> 1;
    if (l <= mid) res += query(u << 1, l, r, x);
    if (r > mid) res += query(u << 1 | 1, l, r, x);
    return res;
}

void modify(int u, int p, int x) {
    update(T[u], w[p], x);
    if (L[u] == R[u]) return;
    int mid = L[u] + R[u] >> 1;
    if (p <= mid) modify(u << 1, p, x);
    else modify(u << 1 | 1, p, x);
}

int get_suc(int root, int x) {
    int u = root, res = INF;
    while(u) {
        if (tr[u].v > x) res = min(res, tr[u].v), u = tr[u].s[0];
        else u = tr[u].s[1];
    }
    return res;
}

int get_pre(int root, int x) {
    int u = root, res = -INF;
    while(u) {
        if (tr[u].v < x) res = max(res, tr[u].v), u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    return res;
}

int query_pre(int u, int l, int r, int x) {
    if (L[u] >= l && R[u] <= r) return get_pre(T[u], x);
    int mid = L[u] + R[u] >> 1, res = -INF;
    if (l <= mid) res = max(res, query_pre(u << 1, l, r, x));
    if (r > mid) res = max(res, query_pre(u << 1 | 1, l, r, x));
    return res;
}

int query_suc(int u, int l, int r, int x) {
    if (L[u] >= l && R[u] <= r) return get_suc(T[u], x);
    int mid = L[u] + R[u] >> 1, res = INF;
    if (l <= mid) res = min(res, query_suc(u << 1, l, r, x));
    if (r > mid) res = min(res, query_suc(u << 1 | 1, l, r, x));
    return res;
}


int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
    build(1, 1, n);

    while (m--) {
        int op, a, b, x;
        scanf("%d", &op);
        if (op == 1) {
            scanf("%d%d%d", &a, &b, &x);
            printf("%d\n", query(1, a, b, x) + 1);
        }
        else if (op == 2)
        {
            scanf("%d%d%d", &a, &b, &x);
            int l = 0, r = 1e8;
            while (l < r)
            {
                int mid = l + r + 1 >> 1;
                if (query(1, a, b, mid) + 1 <= x) l = mid;
                else r = mid - 1;
            }
            printf("%d\n", r);
        }
        else if (op == 3)
        {
            scanf("%d%d", &a, &x);
            modify(1, a, x);
            w[a] = x;
        }
        else if (op == 4)
        {
            scanf("%d%d%d", &a, &b, &x);
            printf("%d\n", query_pre(1, a, b, x));
        }
        else
        {
            scanf("%d%d%d", &a, &b, &x);
            printf("%d\n", query_suc(1, a, b, x));
        }
    }
    return 0;
}

最后再来看一条线段树套线段树的题目:

这题虽短,但涉及的知识点还是很多的,离散化、权值线段树、动态开点、标记持久化等,我们依次分析。

首先我们分析一下操作类型,第一个操作是区间修改,如果把每个位置看作集合,那么就是在一段连续的集合里同时加上一个数,第二个操作是区间查询,那么考虑线段树维护查询操作,内层的树如果用平衡树的话,那么对于集合的维护就很复杂,所以我们可以把外层的树设置为权值线段树。(因为权值的范围为1<<31,而点点个数最多只有50000,所以需要离散化一下。)

权值线段树和线段树维护的东西刚好相反,线段树是建立在下标上的,维护的是每个下标有哪些权值,而权值线段树是建立在权值(数轴)上的,维护的是每个权值在哪些下标里。那么对于第一个操作,线段树上是在一段连续的下标里加上一个权值c,在权值线段树里就是在权值为c的地方加上一段连续的下标。此时的区间修改就变成了单点修改,对于权值c,我们最多修改logn层权值线段树即可。

由于权值线段树里的每个节点都是一棵普通线段树,用来维护下标的修改,所以每执行一次第一个操作,我们就在内层的线段树的a-b所在的位置加上1,这个操作可以用懒标记来维护。

而对于内层的线段树来说,需要维护的操作只有add一个,所以我们可以用一个标记持久化的思想,把add作为线段树的自身属性,这样就不用每次都pushdown操作了,可以节省不少时间。

再其次由于线段树的空间占用非常之大,但是利用率并不高,并且有的节点用过一次之后就不会再被使用了,所以我们可以动态开点,用一个数组存储下所有的节点,相当于节点分配站,当我需要的时候再给他分配。

最后附上完整代码

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;

typedef long long LL;

const int N = 50010, P = N * 17 * 17, M = N * 4;

int n, m;
struct Tree
{
    int l, r;
    LL sum, add;
}tr[P];
int L[M], R[M], T[M], idx;
struct Query
{
    int op, a, b, c;
}q[N];
vector<int> nums;

int get(int x)
{
    return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}

void build(int u, int l, int r)
{
    L[u] = l, R[u] = r, T[u] = ++ idx;
    if (l == r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

int intersection(int a, int b, int c, int d)
{
    return min(b, d) - max(a, c) + 1;
}

void update(int u, int l, int r, int pl, int pr)
{
    tr[u].sum += intersection(l, r, pl, pr);
    if (l >= pl && r <= pr)
    {
        tr[u].add ++ ;
        return;
    }
    int mid = l + r >> 1;
    if (pl <= mid)
    {
        if (!tr[u].l) tr[u].l = ++ idx;
        update(tr[u].l, l, mid, pl, pr);
    }
    if (pr > mid)
    {
        if (!tr[u].r) tr[u].r = ++ idx;
        update(tr[u].r, mid + 1, r, pl, pr);
    }
}

void change(int u, int a, int b, int c)
{
    update(T[u], 1, n, a, b);
    if (L[u] == R[u]) return;
    int mid = L[u] + R[u] >> 1;
    if (c <= mid) change(u << 1, a, b, c);
    else change(u << 1 | 1, a, b, c);
}

LL get_sum(int u, int l, int r, int pl, int pr, int add)
{
    if (l >= pl && r <= pr) return tr[u].sum + (r - l + 1LL) * add;
    int mid = l + r >> 1;
    LL res = 0;
    add += tr[u].add;
    if (pl <= mid)
    {
        if (tr[u].l) res += get_sum(tr[u].l, l, mid, pl, pr, add);
        else res += intersection(l, mid, pl, pr) * add;
    }
    if (pr > mid)
    {
        if (tr[u].r) res += get_sum(tr[u].r, mid + 1, r, pl, pr, add);
        else res += intersection(mid + 1, r, pl, pr) * add;
    }
    return res;
}

int query(int u, int a, int b, int c)
{
    if (L[u] == R[u]) return R[u];
    int mid = L[u] + R[u] >> 1;
    LL k = get_sum(T[u << 1 | 1], 1, n, a, b, 0);
    if (k >= c) return query(u << 1 | 1, a, b, c);
    return query(u << 1, a, b, c - k);
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i ++ )
    {
        scanf("%d%d%d%d", &q[i].op, &q[i].a, &q[i].b, &q[i].c);
        if (q[i].op == 1) nums.push_back(q[i].c);
    }
    sort(nums.begin(), nums.end());
    nums.erase(unique(nums.begin(), nums.end()), nums.end());

    build(1, 0, nums.size() - 1);

    for (int i = 0; i < m; i ++ )
    {
        int op = q[i].op, a = q[i].a, b = q[i].b, c = q[i].c;
        if (op == 1) change(1, a, b, get(c));
        else printf("%d\n", nums[query(1, a, b, c)]);
    }

    return 0;
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值