为什么要用树链剖分?先回顾一下两个问题:
- 将树从 x 到 y 结点最短路径上所有节点的值都加上 z 。
这个问题可以用树上差分来解决 - 求树从x到y结点最短路径上所有节点的值之和
这个问题可以利用 lca 来解决,distance ( x , y ) = dis ( x ) + dis ( y ) - 2 * dis ( lca )
那么如果这两个问题结合呢,这样每更新一次节点,就要 dfs 更新一次 dis ,很不方便,于是我们就可以用树链剖分来解决这类问题。
son[u]:表示u的重儿子
size[u]:表示以u为根的树的节点个数
f[u]:节点的父节点
dep[u]:节点的深度
top[u]:节点u所在链的顶端
id[u]:节点u的新编号(DFS序)
a[cnt]:在新编号(DFS序)下的当前点的点值
w[u]:题目中给出的节点u的点值
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
int n, m, num, cnt, rt, cntt, k, s, mod;
int head[N], size[N], top[N], f[N], son[N], dep[N], a[N], id[N], w[N];
struct node {
int v, nx;
} e[N];
struct tree {
int sum, lazy;
int len;
} t[N];
#define lson rt << 1
#define rson rt << 1 | 1
template<class T>inline void read(T &x) {
x = 0; int f = 0; char ch = getchar();
while (!isdigit(ch)) f |= (ch == '-'), ch = getchar();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
x = f ? -x : x;
return ;
}
inline void add(int u, int v) {
e[++num].nx = head[u], e[num].v = v, head[u] = num;
}
//fx表示x所在链的顶,fy表示y所在链的顶
int LCA(int x, int y) {
int fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
x = f[fx], fx = top[x];
}
return dep[x] < dep[y] ? x : y;//最后返回深度小的那个
}
//u是当前节点
//fa是当前节点的父节点
void dfs1(int u, int fa) {
size[u] = 1;//表示刚搜到u的时候以u为根的子树里只有u一个节点
for (int i = head[u]; ~i; i = e[i].nx) {
int v = e[i].v;//连向的节点
if (v != fa) {//因为连的是无相边,而且是树,不能往上搜,所以我们要判断u是不是从fa搜过来,也就是判断v是不是u的子节点,也可以写作!dep[v](没有被搜到过)
dep[v] = dep[u] + 1;//v的深度是当前节点的深度+1
f[v] = u;//记录一下父亲
//dis[v] = dis[u] + e[i].w;如果有边权这样加
dfs1(v, u);//继续往下搜,一直搜到叶节点为止
size[u] += size[v];//往上回溯,更新以u为根的子树的size
if (size[v] > size[son[u]]) son[u] = v;//重儿子是节点个数更多的子树,如果以u的子树中,以v为根的子树节点多,那就更新一下u的重儿子为v
}
}
}
//u是当前节点
//t是所在链的顶端
void dfs2(int u, int t) {
id[u] = ++cnt;//给这个点一个新的编号
a[cnt] = w[u];//记录这个编号下点的值
top[u] = t;//记录u所在链的顶端为t
if (son[u]) dfs2(son[u], t);//u的重儿子和u在同一条链里
for (int i = head[u]; ~i; i = e[i].nx) {
int v = e[i].v;//搜轻儿子
if (v != f[u] && v != son[u])//判断是否是轻儿子
dfs2(v, v);//以轻儿子为顶的链
}
}
inline void pushup(int rt) {
t[rt].sum = t[lson].sum + t[rson].sum;
}
void build(int l, int r, int rt) {
t[rt].len = r - l + 1;
if (l == r) {
t[rt].sum = a[l];
return;
}
int m = (l + r) >> 1;
build(l, m, lson);
build(m + 1, r, rson);
pushup(rt);
}
inline void pushdown(int rt) {
if (t[rt].lazy) {
t[lson].lazy += t[rt].lazy, t[lson].lazy %= mod;
t[rson].lazy += t[rt].lazy, t[rson].lazy %= mod;
t[lson].sum += t[rt].lazy * t[lson].len, t[lson].sum %= mod;
t[rson].sum += t[rt].lazy * t[rson].len, t[rson].sum %= mod;
t[rt].lazy = 0;
}
}
void update(int L, int R, int c, int l, int r, int rt) {
if (L <= l && r <= R) {
t[rt].sum += (t[rt].len * c) % mod;
t[rt].lazy += c;
return ;
}
pushdown(rt);
int m = (l + r) >> 1;
if (L <= m) update(L, R, c, l, m, lson);
if (R > m) update(L, R, c, m + 1, r, rson);
pushup(rt);
}
int query(int L, int R, int l, int r, int rt) {
if (L <= l && r <= R) return t[rt].sum;
pushdown(rt);
int m = (l + r) >> 1, ans = 0;
if (L <= m) ans += query(L, R, l, m, lson) % mod;
if (R > m) ans += query(L, R, m + 1, r, rson) % mod;
return ans % mod;
}
void update_chain(int x, int y, int z) {
int fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
update(id[fx], id[x], z, 1, cnt, 1);
x = f[fx], fx = top[x];
}
if (id[x] > id[y]) swap(x, y);
update(id[x], id[y], z, 1, cnt, 1);
}
int query_chain(int x, int y) {
int ans = 0, fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
ans += query(id[fx], id[x], 1, cnt, 1);
x = f[fx], fx = top[x];
}
if (id[x] > id[y]) swap(x, y);
ans += query(id[x], id[y], 1, cnt, 1);
return ans % mod;
}
int main(int argc, char const *argv[]) {
memset(head, -1, sizeof(head));
read(n), read(m), read(s), read(mod);
for (int i = 1; i <= n; ++ i) read(w[i]);
for (int i = 1, x, y; i < n; ++ i) {
read(x), read(y);
add(x, y), add(y, x);
}
f[s] = 1, dep[s] = 0;
dfs1(s, 0);
dfs2(s, s);
build(1, n, 1);
for (int i = 1, x, y, z; i <= m; ++i) {
read(k), read(x);
if (k == 1) {read(y), read(z); update_chain(x, y, z);}//将树从x到y结点最短路径上所有节点的值都加上z
if (k == 2) {read(y); printf("%d\n", query_chain(x, y));}//求树从x到y结点最短路径上所有节点的值之和
if (k == 3) {read(z); update(id[x], id[x] + size[x] - 1, z, 1, n, 1);}//以x为根的子树所有点加z
if (k == 4) printf("%d\n", query(id[x], id[x] + size[x] - 1, 1, n, 1) % mod);//求以x为根节点的子树内所有节点值之和
}
return 0;
}
如果需要单点修改以及区间查询,区间维护最值,可以用下面的模板
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 30010, M = 2 * N;
int idx, e[M], ne[M], h[N], w[N];
int son[N], fa[N], sz[N], top[N], dep[N];
int id[N], nw[N], cnt;
int n,m;
struct node{
int l, r;
ll maxx, sum;
}tr[N * 4];
void add(int a, int b){ e[idx] = b, ne[idx] = h[a], h[a] = idx ++;}
void pushup(int u){
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
tr[u].maxx = max(tr[u << 1].maxx, tr[u << 1 | 1].maxx);
}
void build(int u, int l, int r){
tr[u] = {l,r,nw[l],nw[l]};
if(l == r) return ;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void update(int u, int x, int k){
if(tr[u].l == x && tr[u].r == x){
tr[u].maxx = tr[u].sum = k;
return ;
}
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) update(u << 1, x, k);
else update(u << 1 | 1, x, k);
pushup(u);
}
ll query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
int mid = tr[u].l + tr[u].r >> 1;
ll res = 0;
if(l <= mid) res += query(u << 1, l, r);
if(r > mid) res += query(u << 1 | 1, l, r);
return res;
}
ll query_maxx(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) return tr[u].maxx;
int mid = tr[u].l + tr[u].r >> 1;
ll res = -1e9;
if(l <= mid) res = max(res, query_maxx(u << 1, l, r));
if(r > mid) res = max(res, query_maxx(u << 1 | 1, l, r));
return res;
}
void dfs1(int u, int father, int depth){
sz[u] = 1,fa[u] = father,dep[u] = depth;
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == father) continue;
dfs1(j, u, depth + 1);
sz[u] += sz[j];
if(sz[son[u]] < sz[j]) son[u] = j;
}
}
void dfs2(int u, int t){
id[u] = ++cnt, nw[cnt] = w[u], top[u] = t;
if(!son[u]) return ;
dfs2(son[u], t);
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == son[u] || j == fa[u]) continue;
dfs2(j ,j);
}
}
ll query_path_sum(int u, int v){
ll res = 0;
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u, v);
res += query(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u, v);
res += query(1, id[v], id[u]);
return res;
}
ll query_path_maxx(int u, int v){
ll res = -1e9;
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u, v);
res = max(query_maxx(1, id[top[u]], id[u]), res);
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u, v);
res = max(res, query_maxx(1, id[v], id[u]));
return res;
}
void update_point(int u, int k){
update(1, id[u], k);
}
int main(){
while(cin >> n >> m){
memset(h, -1, sizeof h);
idx = 0, cnt = 0;
for(int i = 1; i <= n; ++i){
scanf("%d",&w[i]);
son[i] = 0;
}
for(int i = 1; i <= n - 1; ++i){
int a, b;
scanf("%d%d",&a,&b);
add(a, b), add(b, a);
}
dfs1(1,-1,1);
dfs2(1,1);
build(1,1,n);
while(m --){
int cmd,u,v;
scanf("%d%d%d",&cmd,&u,&v);
if(cmd == 0){
printf("%lld\n", query_path_maxx(u, v));
}else if(cmd == 1){
printf("%lld\n", query_path_sum(u, v));
}else update_point(u, v);
}
}
return 0;
}
987

被折叠的 条评论
为什么被折叠?



