前言
树链剖分将树按照一定规则剖分成若干条链,在数据结构(主要是线段树和Splay)上维护,从而降低时间复杂度。
把树剖分成若干条重链,则一条路径上最多有O(log)条重链,因为每次到达一个重链的顶端,子树大小都至少翻倍。因此重链剖分用于解决树上路径问题。
长链剖分可以用于优化动态规划。
实链剖分可以用于动态树。
树链剖分求LCA
最近写的代码,感觉比原来的好看:(从树剖板子题代码上改过来的)
#include<iostream>
#include<vector>
#include<cmath>
#include<algorithm>
using namespace std;
long long M;
const int N=5e5;
vector<vector<int>> a;
int h[N+5];
int root;
int fa[N+5],dep[N+5],siz[N+5],dfn[N+5],son[N+5],top[N+5],cnt;
int dfs1(int u) {
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto&v:a[u])
if(v^fa[u]) {
fa[v]=u;
if(siz[son[u]]<dfs1(v)) son[u]=v;
siz[u]+=siz[v];
}
return siz[u];
}
void dfs2(int u,int t) {
dfn[u]=++cnt;
top[u]=t;
if(son[u]) dfs2(son[u],t);
for(auto&v:a[u])
if(v^fa[u]&&v^son[u])
dfs2(v,v);
}
struct tree {
int l,r;
long long sum,add;
} t[N<<2];
void build(int u,int l,int r) {
t[u]= {l,r,0,0};
if(l==r) return ;
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
}
void push_up(int u) {
(t[u].sum=t[u<<1].sum+t[u<<1|1].sum)%=M;
}
void push_down(int u) {
int l=u<<1,r=u<<1|1;
(t[l].sum+=t[u].add*(t[l].r-t[l].l+1)%M)%=M;
(t[r].sum+=t[u].add*(t[r].r-t[r].l+1)%M)%=M;
(t[l].add+=t[u].add)%=M;
(t[r].add+=t[u].add)%=M;
t[u].add=0;
}
void push(int u,int l,int r,long long val) {
if(l<=t[u].l&&t[u].r<=r) (t[u].sum+=val*(t[u].r-t[u].l+1)%M)%=M,(t[u].add+=val)%=M;
else {
push_down(u);
int mid=t[u].l+t[u].r>>1;
if(l<=mid) push(u<<1,l,r,val);
if(mid<r) push(u<<1|1,l,r,val);
push_up(u);
}
}
long long find(int u,int l,int r) {
if(l<=t[u].l&&t[u].r<=r) return t[u].sum;
else {
push_down(u);
int mid=t[u].l+t[u].r>>1;
long long ans=0;
if(l<=mid) (ans+=find(u<<1,l,r))%=M;
if(mid<r) (ans+=find(u<<1|1,l,r))%=M;
push_up(u);
return ans;
}
}
int lca(int x,int y) {
while(top[x]^top[y])
if(dep[top[x]]<dep[top[y]])
swap(x,y);
else
x=fa[top[x]];
if(dep[x]>dep[y]) swap(x,y);
return x;
}
void pushx(int x,int y,long long val) {
while(top[x]^top[y])
if(dep[top[x]]<dep[top[y]])
swap(x,y);
else
push(1,dfn[top[x]],dfn[x],val),x=fa[top[x]];
if(dep[x]>dep[y]) swap(x,y);
push(1,dfn[x],dfn[y],val);
}
long long findx(int x,int y,long long sum=0) {
while(top[x]^top[y])
if(dep[top[x]]<dep[top[y]])
swap(x,y);
else
(sum+=find(1,dfn[top[x]],dfn[x]))%=M,x=fa[top[x]];
if(dep[x]>dep[y]) swap(x,y);
(sum+=find(1,dfn[x],dfn[y]))%=M;
return sum;
}
int main() {
int n,m;
cin>>n>>m>>root;
a.resize(n+1);
for(int i=1,u,v; i<n; i++) {
cin>>u>>v;
a[u].push_back(v);
a[v].push_back(u);
}
dfs1(root);
dfs2(root,root);
// build(1,1,n);
// for(int i=1; i<=n; i++) push(1,dfn[i],dfn[i],h[i]);
// cout<<"***"<<endl;
// for(int i=1;i<=n;i++) cout<<i<<' '<<find(1,dfn[i],dfn[i])<<endl;
// cout<<"***"<<endl;
// cout<<"***"<<endl;
// for(int i=1;i<=n;i++) cout<<i<<' '<<dfn[i]<<endl;
// cout<<"***"<<endl;
while(m--) {
int x,y;
cin>>x>>y;
cout<<lca(x,y)<<endl;
}
return 0;
}
老板代码:
#include<cstdio>
#include<vector>
using namespace std;
int n,m,s;
vector<vector<int>> a;
int fa[500005],size[500005],deep[500005];
int son[500005];
int dfs1(int u) {
if(size[u]) return size[u];
size[u]=1;
deep[u]=deep[fa[u]]+1;
for(auto&v:a[u]) {
if(v==fa[u]) continue;
fa[v]=u;
dfs1(v);
size[u]+=size[v];
if(size[son[u]]<size[v])
son[u]=v;
}
return size[u];
}
int top[500005];
void dfs2(int u) {
if(son[fa[u]]==u)
top[u]=top[fa[u]];
else
top[u]=u;
for(auto&v:a[u])
if(v^fa[u])
dfs2(v);
}
int lca(int x,int y) {
while(top[x]^top[y]) {
if(deep[top[x]]<deep[top[y]]) swap(x,y);
x=fa[top[x]];
}
return deep[x]<deep[y]?x:y;
}
int main() {
scanf("%d%d%d",&n,&m,&s);
for(int i=0;i<=n;i++) a.push_back({});
for(int i=1;i<n;i++) {
int u,v;
scanf("%d%d",&u,&v);
a[u].push_back(v);
a[v].push_back(u);
}
dfs1(s);
dfs2(s);
while(m--) {
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}
老板代码的主要细节:
- 正常来说dfs2需要先走重儿子,但是这道题不需要线段树维护,因此不需要重链编号连续,因此不需要这样设计。
- 注意lca函数中while循环的条件,top如果相同即跳过
- 注意lca函数中,deep[top[x]]<deep[top[y]],才交换,而不是deep[x]<deep[y]
- 注意lca函数中,x=fa[top[x]],而不是top[x]
树链剖分维护树上修改与查询
利用线段树可以维护一个连续区间,因此我们给节点重新编号,使得重链上的节点连续,便于维护。
新版代码:
#include<iostream>
#include<vector>
#include<cmath>
#include<algorithm>
using namespace std;
long long M;
const int N=1e5;
vector<vector<int>> a;
int h[N+5];
int root;
int fa[N+5],dep[N+5],siz[N+5],dfn[N+5],son[N+5],top[N+5],cnt;
int dfs1(int u) {
siz[u]=1;
dep[u]=dep[fa[u]]+1;
for(auto&v:a[u])
if(v^fa[u]) {
fa[v]=u;
if(siz[son[u]]<dfs1(v)) son[u]=v;
siz[u]+=siz[v];
}
return siz[u];
}
void dfs2(int u,int t) {
dfn[u]=++cnt;
top[u]=t;
if(son[u]) dfs2(son[u],t);
for(auto&v:a[u])
if(v^fa[u]&&v^son[u])
dfs2(v,v);
}
struct tree {
int l,r;
long long sum,add;
} t[N<<2];
void build(int u,int l,int r) {
t[u]= {l,r,0,0};
if(l==r) return ;
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
}
void push_up(int u) {
(t[u].sum=t[u<<1].sum+t[u<<1|1].sum)%=M;
}
void push_down(int u) {
int l=u<<1,r=u<<1|1;
(t[l].sum+=t[u].add*(t[l].r-t[l].l+1)%M)%=M;
(t[r].sum+=t[u].add*(t[r].r-t[r].l+1)%M)%=M;
(t[l].add+=t[u].add)%=M;
(t[r].add+=t[u].add)%=M;
t[u].add=0;
}
void push(int u,int l,int r,long long val) {
if(l<=t[u].l&&t[u].r<=r) (t[u].sum+=val*(t[u].r-t[u].l+1)%M)%=M,(t[u].add+=val)%=M;
else {
push_down(u);
int mid=t[u].l+t[u].r>>1;
if(l<=mid) push(u<<1,l,r,val);
if(mid<r) push(u<<1|1,l,r,val);
push_up(u);
}
}
long long find(int u,int l,int r) {
if(l<=t[u].l&&t[u].r<=r) return t[u].sum;
else {
push_down(u);
int mid=t[u].l+t[u].r>>1;
long long ans=0;
if(l<=mid) (ans+=find(u<<1,l,r))%=M;
if(mid<r) (ans+=find(u<<1|1,l,r))%=M;
push_up(u);
return ans;
}
}
//int lca(int x,int y) {
// while(top[x]^top[y])
// if(dep[top[x]]<dep[top[y]])
// swap(x,y);
// else
// x=fa[top[x]];
// if(dep[x]>dep[y]) swap(x,y);
// return x;
//}
void pushx(int x,int y,long long val) {
while(top[x]^top[y])
if(dep[top[x]]<dep[top[y]])
swap(x,y);
else
push(1,dfn[top[x]],dfn[x],val),x=fa[top[x]];
if(dep[x]>dep[y]) swap(x,y);
push(1,dfn[x],dfn[y],val);
}
long long findx(int x,int y,long long sum=0) {
while(top[x]^top[y])
if(dep[top[x]]<dep[top[y]])
swap(x,y);
else
(sum+=find(1,dfn[top[x]],dfn[x]))%=M,x=fa[top[x]];
if(dep[x]>dep[y]) swap(x,y);
(sum+=find(1,dfn[x],dfn[y]))%=M;
return sum;
}
int main() {
int n,m;
cin>>n>>m>>root>>M;
a.resize(n+1);
for(int i=1; i<=n; i++) cin>>h[i];
for(int i=1,u,v; i<n; i++) {
cin>>u>>v;
a[u].push_back(v);
a[v].push_back(u);
}
dfs1(root);
dfs2(root,root);
build(1,1,n);
for(int i=1; i<=n; i++) push(1,dfn[i],dfn[i],h[i]);
// cout<<"***"<<endl;
// for(int i=1;i<=n;i++) cout<<i<<' '<<find(1,dfn[i],dfn[i])<<endl;
// cout<<"***"<<endl;
// cout<<"***"<<endl;
// for(int i=1;i<=n;i++) cout<<i<<' '<<dfn[i]<<endl;
// cout<<"***"<<endl;
while(m--) {
int op;
cin>>op;
switch(op) {
case 1: {
int x,y;
long long val;
cin>>x>>y>>val;
pushx(x,y,val);
// cout<<"***"<<endl;
// for(int i=1; i<=n; i++)
// cout<<i<<':'<<find(1,dfn[i],dfn[i])<<endl;
// cout<<"***"<<endl;
break;
}
case 2: {
int x,y;
cin>>x>>y;
cout<<findx(x,y)<<endl;
break;
}
case 3: {
int x;
long long val;
cin>>x>>val;
push(1,dfn[x],dfn[x]+siz[x]-1,val);
break;
}
case 4: {
int x;
cin>>x;
cout<<find(1,dfn[x],dfn[x]+siz[x]-1)<<endl;
break;
}
}
}
return 0;
}
这里主要说一下怎么跳树链:
主要就是考虑到树剖求LCA的过程中,如果x,y不在同一条重链,我们每一步都跳x,y中重链顶端深度较深的点。
如果x,y在同一条重链,则x,y中深度较浅的结点就是原本x,y的LCA。
因此我们分两部分统计答案:
- 如果x,y不在同一条重链中,则我们对包含x的那一部分重链(从x~top[x])统计答案,然后跳。
- 如果x,y在一条重链上,则我们对x,y直接统计答案,然后返回。
此外,新版代码直接用dfn[x]+siz[x]-1来维护x子树中最后一个结点的dfn。而没有采用追溯值。
老板的代码:
#include<iostream>
#include<vector>
using namespace std;
int n,m,s,p;
vector<vector<int>> a;
typedef int intx[100005];
intx b,size,son,top,id,deep,fa,low;
//b:初始权值
//id:新编号
//low:追溯值,即本颗子树的最后一个节点的编号
int cnt;
void dfs1(int u) {
size[u]=1;
deep[u]=deep[fa[u]]+1;
for(auto&v:a[u]) {
if(v==fa[u]) continue;
fa[v]=u;
dfs1(v);
size[u]+=size[v];
if(size[son[u]]<size[v])
son[u]=v;
}
}
void dfs2(int u) {
id[u]=++cnt;
if(u==son[fa[u]]) top[u]=top[fa[u]];
else top[u]=u;
if(son[u]) dfs2(son[u]);
for(auto&v:a[u])
if(v^son[u]&&v^fa[u])
dfs2(v);
low[u]=cnt;
}
int lca(int x,int y) {
while(top[x]^top[y])
if(deep[top[x]]>deep[top[y]]) x=fa[top[x]];
else y=fa[top[y]];
return deep[x]<deep[y]?x:y;
}
struct node {
long long l,r;
long long add,sum;
} t[100005<<2];
void build(int u,int l,int r) {
t[u]= {l,r,0,0};
if(l==r) return ;
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
}
void push_up(int u) {
t[u].sum=t[u<<1].sum+t[u<<1|1].sum;
}
void push_down(int u) {
if(t[u].add) {
int l=u<<1,r=u<<1|1;
(t[l].sum+=(t[l].r-t[l].l+1)*t[u].add)%=p;
(t[r].sum+=(t[r].r-t[r].l+1)*t[u].add)%=p;
(t[l].add+=t[u].add)%=p;
(t[r].add+=t[u].add)%=p;
t[u].add=0;
}
}
void push(int u,int l,int r,long long v) {
if(l<=t[u].l&&t[u].r<=r) {
(t[u].sum+=(t[u].r-t[u].l+1)*v)%=p;
(t[u].add+=v)%=p;
return ;
}
push_down(u);
int mid=t[u].l+t[u].r>>1;
if(l<=mid) push(u<<1,l,r,v);
if(mid<r) push(u<<1|1,l,r,v);
push_up(u);
}
long long find(int u,int l,int r) {
if(l<=t[u].l&&t[u].r<=r) return t[u].sum;
push_down(u);
int mid=t[u].l+t[u].r>>1;
long long ans=0;
if(l<=mid) (ans+=find(u<<1,l,r))%=p;
if(mid<r) (ans+=find(u<<1|1,l,r))%=p;
return ans;
}
void pushx(int x,int y,long long v) {
while(top[x]^top[y]) {
if(deep[top[x]]<deep[top[y]]) swap(x,y);
push(1,id[top[x]],id[x],v);
x=fa[top[x]];
}
if(id[x]>id[y]) swap(x,y);
push(1,id[x],id[y],v);
}
long long findx(int x,int y) {
long long ans=0;
while(top[x]^top[y]) {
if(deep[top[x]]<deep[top[y]]) swap(x,y);
(ans+=find(1,id[top[x]],id[x]))%=p;
x=fa[top[x]];
}
if(id[x]>id[y]) swap(x,y);
(ans+=find(1,id[x],id[y]))%=p;
return ans;
}
int main() {
cin>>n>>m>>s>>p;
build(1,1,n);
for(int i=0; i<=n; i++) a.push_back({});
for(int i=1; i<=n; i++) cin>>b[i];
for(int i=1; i<n; i++) {
int u,v;
cin>>u>>v;
a[u].push_back(v);
a[v].push_back(u);
}
dfs1(s);
dfs2(s);
for(int i=1; i<=n; i++)
push(1,id[i],id[i],b[i]);
// for(int i=1;i<=n;i++)
// cout<<i<<':'<<find(1,id[i],id[i])<<endl;
while(m--) {
int op;
cin>>op;
switch(op) {
case 1: {
int x,y;
long long z;
cin>>x>>y>>z;
pushx(x,y,z);
// cout<<"*"<<endl;
// for(int i=1; i<=n; i++)
// cout<<i<<':'<<find(1,id[i],id[i])<<endl;
break;
}
case 2: {
int x,y;
cin>>x>>y;
cout<<findx(x,y)<<endl;
break;
}
case 3: {
int x;
long long z;
cin>>x>>z;
push(1,id[x],low[x],z);
// cout<<"*"<<endl;
// for(int i=1; i<=n; i++)
// cout<<i<<':'<<find(1,id[i],id[i])<<endl;
break;
}
case 4: {
int x;
cin>>x;
cout<<find(1,id[x],low[x])<<endl;
break;
}
}
}
return 0;
}
- 注意findx和pushx函数在while循环外面还要再执行一次
- 注意传入find和push里面的参数是id[x],而不是x
- 一条重链上的节点编号是连续的,可以快速维护,一颗子树上的节点编号也是连续的,也可以快速维护。
后记
于是皆大欢喜