基本概念:
1.重儿子:假设节点u有n个子结点,其中以v子节点的为根子树的大小最大,那么v就是u的重儿子
2.轻儿子:除了重儿子以外的全部儿子都是轻儿子
3.轻边:结点u与轻儿子连接的边
4.重边:结点u与重儿子连接的边
5.轻链:均由轻儿子组成的一条链
6.重链:均由重儿子组成的一条链
预处理节点信息:
dep[u]:u节点的深度
fa[u]:u结点的父亲结点
son[u]:u结点的重儿子
siz[u]:以u节点为根的子树的大小
top[u]:u结点所在的链的顶点
首先,我们可以很简单地通过dfs获取一个结点的dep,fa和siz,从而也就获得了siz
实现代码如下:
#include <iostream>
using namespace std;
const int N = 2E5 + 10;
int dep[N], fa[N], son[N], siz[N], top[N];
int to[N << 1], nxt[N << 1], h[N], tot;
void dfs1(int u,int f){
siz[u] = 1;
dep[u] = dep[f] + 1;
fa[u] = f;
int max = 0;
for (int i = h[u], v; v = to[i];i=nxt[i]){
if(v==f){
continue;
}
dfs1(v, u);
siz[u] += siz[v];
if(siz[v]>max){
max = siz[v];
son[u] = v;
}
}
}
接下来,我们再用一个dfs来获取top数组
处理的方式为:
重儿子的top就等于自己u节点的top
轻儿子的top就等于轻儿子本身
实现代码如下:
void dfs2(int u,int f){
for (int i = h[u], v; v = to[i];i=nxt[i]){
if(v==f){
continue;
}
if(v==son[u]){
top[v] = top[u];
}
else{
top[v] = v;
}
dfs2(v, u);
}
}
树链剖分的应用:
1.寻找最近公共祖先(lca)
对于两个结点x和y
假设他们在同一条链上,也就是top相同,那么他们的lca就是深度比较小的一方
如果他们不在同一条链上
我们知道,他们肯定可以通过若干条链走到同一条链上
如何将一个结点从一条链转移到另一条链呢?
只需要让top[u]的深度较大的一方跳到其top[u]的父亲结点上,自然就到了另一条新链了
而且可以保证他们两个的top越来越接近,直到top相同
实现代码如下:
int lca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]){
swap(x, y);
}
x = fa [top [x]];
}
return dep[x] < dep[y] ? x : y;
}
2.维护树上区间
我们轻重链剖分以后,每条链都是一个连续的区间
如果想要对路径(x,y)做区间修改和区间查询的操作
只需要对组成这条路径的若干条树链进行维护即可
具体操作为:
在跳跃到x与y在共同链之前
我们对区间dfn[top[x]]到dfn[x]进行修改,查询
维护区间的数据结构我们选择使用线段树
题目链接:P3384 【模板】重链剖分/树链剖分 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
实现代码如下(码了一个小时,找bug找了一个小时,还是对线段树不是很熟练):
#include <iostream>
using namespace std;
const int N = 1E5 + 10;
#define ll long long
#define ls (i << 1)
#define rs (i << 1 | 1)
#define mid (left + right >> 1)
int to[N << 1], nxt[N << 1], h[N], tot;
int dfn[N], siz[N], son[N], top[N], dep[N], fa[N], idx;
// dfn[u],轻重链u结点的dfs序(区间序号)
// v[u],u结点的点权
// a[i],区间下标i的点权
int a[N], v[N];
int n, m, s;
long long mod;
// 加边
void add(int a, int b)
{
to[++tot] = b;
nxt[tot] = h[a];
h[a] = tot;
}
// 获取每个结点的根树的size,获取深度dep,获取儿子son
void dfs1(int u, int f)
{
dep[u] = dep[f] + 1;
siz[u] = 1;
fa[u] = f;
int max = 0;
for (int i = h[u], v; v = to[i]; i = nxt[i])
{
if (v == f)
{
continue;
}
dfs1(v, u);
if (siz[v] > max)
{
max = siz[v];
son[u] = v;
}
siz[u] += siz[v];
}
}
// 获取dfn序和top
void dfs2(int u, int f)
{
dfn[u] = ++idx;
a[idx] = v[u];
if (son[u])
{
top[son[u]] = top[u];
dfs2(son[u], u);
}
for (int i = h[u], v; v = to[i]; i = nxt[i])
{
if (v == f || v == son[u])
{
continue;
}
top[v] = v;
dfs2(v, u);
}
}
//线段树
struct node
{
int l, r;
ll sum;
ll tag;
} tr[4 * N];
void pushup(int i)
{
tr[i].sum = (tr[ls].sum + tr[rs].sum) % mod;
}
void pushdown(int i)
{
if (tr[i].l != tr[i].r && tr[i].tag)
{
tr[ls].sum = (tr[ls].sum + ((tr[ls].r-tr[ls].l+1)%mod*tr[i].tag)) % mod;
tr[rs].sum = (tr[rs].sum + ((tr[rs].r-tr[rs].l+1)*tr[i].tag)) % mod;
tr[rs].tag = (tr[i].tag+tr[rs].tag)%mod;
tr[ls].tag = (tr[i].tag+tr[ls].tag)%mod;
tr[i].tag = 0;
}
}
// 建树
void build(int i, int left, int right)
{
tr[i].l = left;
tr[i].r = right;
if (left == right)
{
tr[i].sum = a[left];
return;
}
build(ls, left, mid);
build(rs, mid + 1, right);
pushup(i);
}
void add(int i, ll k, int left, int right)
{
if (tr[i].l >= left && tr[i].r <= right)
{
tr[i].sum = (((tr[i].r - tr[i].l + 1) % mod * k) % mod + tr[i].sum) % mod;
tr[i].tag = (k + tr[i].tag) % mod;
return;
}
int mmid = (tr[i].l + tr[i].r >> 1);
pushdown(i);
if (right >= mmid + 1)
{
add(rs, k, left, right);
}
if (left <= mmid)
{
add(ls, k, left, right);
}
pushup(i);
}
ll search(int i, int left, int right)
{
if (tr[i].l >= left && tr[i].r <= right)
{
return tr[i].sum;
}
pushdown(i);
ll res = 0;
int mmid = (tr[i].l + tr[i].r >> 1);
if (right >= mmid + 1)
{
res += search(rs, left, right);
}
if (left <= mmid)
{
res = (res + search(ls, left, right)) % mod;
}
return res;
}
int main()
{
cin >> n >> m >> s >> mod;
for (int i = 1; i <= n; i++)
{
cin >> v[i];
v[i] %= mod;
}
for (int i = 1, x, y; i < n; i++)
{
cin >> x >> y;
add(x, y);
add(y, x);
}
dfs1(s, 0);
top[s] = s;
dfs2(s, 0);
build(1, 1, n);
ll z;
int l, r;
ll ans;
for (int i = 1, opt, x, y; i <= m; i++)
{
cin >> opt;
if (opt == 1)
{
cin >> x >> y >> z;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
{
swap(x, y);
}
add(1, z, dfn[top[x]], dfn[x]);
x = fa[top[x]];
}
if (dep[x] < dep[y])
{
l = dfn[x];
r = dfn[y];
}
else
{
l = dfn[y];
r = dfn[x];
}
add(1, z, l, r);
}
else if (opt == 2)
{
cin >> x >> y;
ans = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
{
swap(x, y);
}
ans = (ans + search(1, dfn[top[x]], dfn[x])) % mod;
x = fa[top[x]];
}
if (dep[x] < dep[y])
{
l = dfn[x];
r = dfn[y];
}
else
{
l = dfn[y];
r = dfn[x];
}
ans = (ans + search(1, l, r)) % mod;
cout << ans << endl;
}
else if (opt == 3)
{
cin >> x >> z;
add(1, z, dfn[x], dfn[x] + siz[x] - 1);
}
else if (opt == 4)
{
cin >> x;
cout << search(1, dfn[x], dfn[x] + siz[x] - 1) << endl;
}
}
return 0;
}