树链剖分模板整理

为什么要用树链剖分?先回顾一下两个问题:

  1. 将树从 x 到 y 结点最短路径上所有节点的值都加上 z 。
    这个问题可以用树上差分来解决
  2. 求树从x到y结点最短路径上所有节点的值之和
    这个问题可以利用 lca 来解决,distance ( x , y ) = dis ( x ) + dis ( y ) - 2 * dis ( lca )

那么如果这两个问题结合呢,这样每更新一次节点,就要 dfs 更新一次 dis ,很不方便,于是我们就可以用树链剖分来解决这类问题。

son[u]:表示u的重儿子
size[u]:表示以u为根的树的节点个数
f[u]:节点的父节点
dep[u]:节点的深度
top[u]:节点u所在链的顶端
id[u]:节点u的新编号(DFS序)
a[cnt]:在新编号(DFS序)下的当前点的点值
w[u]:题目中给出的节点u的点值

#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
int n, m, num, cnt, rt, cntt, k, s, mod;
int head[N], size[N], top[N], f[N], son[N], dep[N], a[N], id[N], w[N];
struct node {
    int v, nx;
} e[N];
struct tree {
    int sum, lazy;
    int len;
} t[N];

#define lson rt << 1
#define rson rt << 1 | 1

template<class T>inline void read(T &x) {
    x = 0; int f = 0; char ch = getchar();
    while (!isdigit(ch)) f |= (ch == '-'), ch = getchar();
    while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
    x = f ? -x : x;
    return ;
}

inline void add(int u, int v) {
    e[++num].nx = head[u], e[num].v = v, head[u] = num;
}

//fx表示x所在链的顶,fy表示y所在链的顶
int LCA(int x, int y) {
    int fx = top[x], fy = top[y];
    while (fx != fy) {
        if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
        x = f[fx], fx = top[x];
    }
    return dep[x] < dep[y] ? x : y;//最后返回深度小的那个
}

//u是当前节点
//fa是当前节点的父节点
void dfs1(int u, int fa) {
    size[u] = 1;//表示刚搜到u的时候以u为根的子树里只有u一个节点
    for (int i = head[u]; ~i; i = e[i].nx) {
        int v = e[i].v;//连向的节点
        if (v != fa) {//因为连的是无相边,而且是树,不能往上搜,所以我们要判断u是不是从fa搜过来,也就是判断v是不是u的子节点,也可以写作!dep[v](没有被搜到过)
            dep[v] = dep[u] + 1;//v的深度是当前节点的深度+1
            f[v] = u;//记录一下父亲
            //dis[v] = dis[u] + e[i].w;如果有边权这样加 
            dfs1(v, u);//继续往下搜,一直搜到叶节点为止
            size[u] += size[v];//往上回溯,更新以u为根的子树的size
            if (size[v] > size[son[u]]) son[u] = v;//重儿子是节点个数更多的子树,如果以u的子树中,以v为根的子树节点多,那就更新一下u的重儿子为v
        }
    }
}

//u是当前节点
//t是所在链的顶端
void dfs2(int u, int t) {
    id[u] = ++cnt;//给这个点一个新的编号
    a[cnt] = w[u];//记录这个编号下点的值
    top[u] = t;//记录u所在链的顶端为t
    if (son[u]) dfs2(son[u], t);//u的重儿子和u在同一条链里
    for (int i = head[u]; ~i; i = e[i].nx) {
        int v = e[i].v;//搜轻儿子
        if (v != f[u] && v != son[u])//判断是否是轻儿子 
            dfs2(v, v);//以轻儿子为顶的链
    }
}

inline void pushup(int rt) {
    t[rt].sum = t[lson].sum + t[rson].sum;
}

void build(int l, int r, int rt) {
    t[rt].len = r - l + 1;
    if (l == r) {
        t[rt].sum = a[l];
        return;
    }
    int m = (l + r) >> 1;
    build(l, m, lson);
    build(m + 1, r, rson);
    pushup(rt);
}

inline void pushdown(int rt) {
    if (t[rt].lazy) {
        t[lson].lazy += t[rt].lazy, t[lson].lazy %= mod;
        t[rson].lazy += t[rt].lazy, t[rson].lazy %= mod;
        t[lson].sum += t[rt].lazy * t[lson].len, t[lson].sum %= mod;
        t[rson].sum += t[rt].lazy * t[rson].len, t[rson].sum %= mod;
        t[rt].lazy = 0;
    }
}

void update(int L, int R, int c, int l, int r, int rt) {
    if (L <= l && r <= R) {
        t[rt].sum += (t[rt].len * c) % mod;
        t[rt].lazy += c;
        return ;
    }
    pushdown(rt);
    int m = (l + r) >> 1;
    if (L <= m) update(L, R, c, l, m, lson);
    if (R > m) update(L, R, c, m + 1, r, rson);
    pushup(rt);
}

int query(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return t[rt].sum;
    pushdown(rt);
    int m = (l + r) >> 1, ans = 0;
    if (L <= m) ans += query(L, R, l, m, lson) % mod;
    if (R > m) ans += query(L, R, m + 1, r, rson) % mod;
    return ans % mod;
}

void update_chain(int x, int y, int z) {
    int fx = top[x], fy = top[y];
    while (fx != fy) {
        if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
        update(id[fx], id[x], z, 1, cnt, 1);
        x = f[fx], fx = top[x];
    }
    if (id[x] > id[y]) swap(x, y);
    update(id[x], id[y], z, 1, cnt, 1);
}

int query_chain(int x, int y) {
    int ans = 0, fx = top[x], fy = top[y];
    while (fx != fy) {
        if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
        ans += query(id[fx], id[x], 1, cnt, 1);
        x = f[fx], fx = top[x];
    }
    if (id[x] > id[y]) swap(x, y);
    ans += query(id[x], id[y], 1, cnt, 1);
    return ans % mod;
}


int main(int argc, char const *argv[]) {
    memset(head, -1, sizeof(head));
    read(n), read(m), read(s), read(mod);
    for (int i = 1; i <= n; ++ i) read(w[i]);
    for (int i = 1, x, y; i < n; ++ i) {
        read(x), read(y);
        add(x, y), add(y, x);
    }
    f[s] = 1, dep[s] = 0;
    dfs1(s, 0);
    dfs2(s, s);
    build(1, n, 1);
    for (int i = 1, x, y, z; i <= m; ++i) {
        read(k), read(x);
        if (k == 1) {read(y), read(z); update_chain(x, y, z);}//将树从x到y结点最短路径上所有节点的值都加上z
        if (k == 2) {read(y); printf("%d\n", query_chain(x, y));}//求树从x到y结点最短路径上所有节点的值之和 
        if (k == 3) {read(z); update(id[x], id[x] + size[x] - 1, z, 1, n, 1);}//以x为根的子树所有点加z 
        if (k == 4) printf("%d\n", query(id[x], id[x] + size[x] - 1, 1, n, 1) % mod);//求以x为根节点的子树内所有节点值之和
    }
    return 0;
}

如果需要单点修改以及区间查询,区间维护最值,可以用下面的模板

#include<bits/stdc++.h>
#define ll long long 
using namespace std;

const int N = 30010, M = 2 * N;

int idx, e[M], ne[M], h[N], w[N];
int son[N], fa[N], sz[N], top[N], dep[N];
int id[N], nw[N], cnt;
int n,m;

struct node{
	int l, r;
	ll maxx, sum;
}tr[N * 4];

void add(int a, int b){ e[idx] = b, ne[idx] = h[a], h[a] = idx ++;}

void pushup(int u){ 
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
	tr[u].maxx = max(tr[u << 1].maxx, tr[u << 1 | 1].maxx);
}

void build(int u, int l, int r){
	tr[u] = {l,r,nw[l],nw[l]};
	if(l == r)	return ;
	int mid = l + r >> 1;
	build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
	pushup(u);
}

void update(int u, int x, int k){
	if(tr[u].l == x && tr[u].r == x){
		tr[u].maxx = tr[u].sum = k;
		return ;
	}
	int mid = tr[u].l + tr[u].r >> 1;
	if(x <= mid)	update(u << 1, x, k);
	else			update(u << 1 | 1, x, k);
	pushup(u);
}

ll query(int u, int l, int r){
	if(tr[u].l >= l && tr[u].r <= r)	return tr[u].sum;
	int mid = tr[u].l + tr[u].r >> 1;
	ll res = 0;
	if(l <= mid)	res += query(u << 1, l, r);
	if(r > mid)		res += query(u << 1 | 1, l, r);
	return res;
}

ll query_maxx(int u, int l, int r){
	if(tr[u].l >= l && tr[u].r <= r)	return tr[u].maxx;
	int mid = tr[u].l + tr[u].r >> 1;
	ll res = -1e9;
	if(l <= mid)	res = max(res, query_maxx(u << 1, l, r));
	if(r > mid)		res = max(res, query_maxx(u << 1 | 1, l, r));
	return res;
}

void dfs1(int u, int father, int depth){
	sz[u] = 1,fa[u] = father,dep[u] = depth;
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == father)	continue;
		dfs1(j, u, depth + 1);
		sz[u] += sz[j];
		if(sz[son[u]] < sz[j])	son[u] = j;
	}
}

void dfs2(int u, int t){
	id[u] = ++cnt, nw[cnt] = w[u], top[u] = t;
	if(!son[u])	return ;
	dfs2(son[u], t);
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == son[u] || j == fa[u])	continue;
		dfs2(j ,j);
	}
}

ll query_path_sum(int u, int v){
	ll res = 0;
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]])	swap(u, v);
		res += query(1, id[top[u]], id[u]);
		u = fa[top[u]];
	}
	if(dep[u] < dep[v])	swap(u, v);
	res += query(1, id[v], id[u]);
	return res;
}

ll query_path_maxx(int u, int v){
	ll res = -1e9;
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]])	swap(u, v);
		res = max(query_maxx(1, id[top[u]], id[u]), res);
		u = fa[top[u]];
	}
	if(dep[u] < dep[v])	swap(u, v);
	res = max(res, query_maxx(1, id[v], id[u]));
	return res;
}

void update_point(int u, int k){
	update(1, id[u], k);
}

int main(){
    while(cin >> n >> m){
        memset(h, -1, sizeof h);
        idx = 0, cnt = 0;
        for(int i = 1; i <= n; ++i){
            scanf("%d",&w[i]);
            son[i] = 0;
        }	
        for(int i = 1; i <= n - 1; ++i){
            int a, b;
            scanf("%d%d",&a,&b);
            add(a, b), add(b, a);
        }
        dfs1(1,-1,1);
        dfs2(1,1);
        build(1,1,n);

        while(m --){
            int cmd,u,v;
            scanf("%d%d%d",&cmd,&u,&v);
            if(cmd == 0){
                printf("%lld\n", query_path_maxx(u, v));
            }else if(cmd == 1){
                printf("%lld\n", query_path_sum(u, v));
            }else update_point(u, v);
        }
    }
	return 0;
}
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值