重链剖分的一些细节

文章介绍了树链剖分的概念,它用于降低树上路径问题的时间复杂度,通过重链剖分和长链剖分优化动态规划和动态树操作。同时,展示了如何使用线段树维护树上信息,并提供了代码示例,包括树的DFS遍历和LCA计算,以及如何进行树上修改和查询操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

树链剖分将树按照一定规则剖分成若干条链,在数据结构(主要是线段树和Splay)上维护,从而降低时间复杂度。
把树剖分成若干条重链,则一条路径上最多有O(log)条重链,因为每次到达一个重链的顶端,子树大小都至少翻倍。因此重链剖分用于解决树上路径问题。
长链剖分可以用于优化动态规划。
实链剖分可以用于动态树。

树链剖分求LCA

最近写的代码,感觉比原来的好看:(从树剖板子题代码上改过来的)

#include<iostream>
#include<vector>
#include<cmath>
#include<algorithm>
using namespace std;
long long M;
const int N=5e5;
vector<vector<int>> a;
int h[N+5];
int root;
int fa[N+5],dep[N+5],siz[N+5],dfn[N+5],son[N+5],top[N+5],cnt;
int dfs1(int u) {
	siz[u]=1;
	dep[u]=dep[fa[u]]+1;
	for(auto&v:a[u])
		if(v^fa[u]) {
			fa[v]=u;
			if(siz[son[u]]<dfs1(v)) son[u]=v;
			siz[u]+=siz[v];
		}
	return siz[u];
}
void dfs2(int u,int t) {
	dfn[u]=++cnt;
	top[u]=t;
	if(son[u]) dfs2(son[u],t);
	for(auto&v:a[u])
		if(v^fa[u]&&v^son[u])
			dfs2(v,v);
}
struct tree {
	int l,r;
	long long sum,add;
} t[N<<2];
void build(int u,int l,int r) {
	t[u]= {l,r,0,0};
	if(l==r) return ;
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
}
void push_up(int u) {
	(t[u].sum=t[u<<1].sum+t[u<<1|1].sum)%=M;
}
void push_down(int u) {
	int l=u<<1,r=u<<1|1;
	(t[l].sum+=t[u].add*(t[l].r-t[l].l+1)%M)%=M;
	(t[r].sum+=t[u].add*(t[r].r-t[r].l+1)%M)%=M;
	(t[l].add+=t[u].add)%=M;
	(t[r].add+=t[u].add)%=M;
	t[u].add=0;
}
void push(int u,int l,int r,long long val) {
	if(l<=t[u].l&&t[u].r<=r) (t[u].sum+=val*(t[u].r-t[u].l+1)%M)%=M,(t[u].add+=val)%=M;
	else {
		push_down(u);
		int mid=t[u].l+t[u].r>>1;
		if(l<=mid) push(u<<1,l,r,val);
		if(mid<r) push(u<<1|1,l,r,val);
		push_up(u);
	}
}
long long find(int u,int l,int r) {
	if(l<=t[u].l&&t[u].r<=r) return t[u].sum;
	else {
		push_down(u);
		int mid=t[u].l+t[u].r>>1;
		long long ans=0;
		if(l<=mid) (ans+=find(u<<1,l,r))%=M;
		if(mid<r) (ans+=find(u<<1|1,l,r))%=M;
		push_up(u);
		return ans;
	}
}
int lca(int x,int y) {
	while(top[x]^top[y])
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		else
			x=fa[top[x]];
	if(dep[x]>dep[y]) swap(x,y);
	return x;
}
void pushx(int x,int y,long long val) {
	while(top[x]^top[y])
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		else
			push(1,dfn[top[x]],dfn[x],val),x=fa[top[x]];
	if(dep[x]>dep[y]) swap(x,y);
	push(1,dfn[x],dfn[y],val);
}
long long findx(int x,int y,long long sum=0) {
	while(top[x]^top[y])
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		else
			(sum+=find(1,dfn[top[x]],dfn[x]))%=M,x=fa[top[x]];
	if(dep[x]>dep[y]) swap(x,y);
	(sum+=find(1,dfn[x],dfn[y]))%=M;
	return sum;
}
int main() {
	int n,m;
	cin>>n>>m>>root;
	a.resize(n+1);
	for(int i=1,u,v; i<n; i++) {
		cin>>u>>v;
		a[u].push_back(v);
		a[v].push_back(u);
	}
	dfs1(root);
	dfs2(root,root);
//	build(1,1,n);
//	for(int i=1; i<=n; i++) push(1,dfn[i],dfn[i],h[i]);
//	cout<<"***"<<endl;
//	for(int i=1;i<=n;i++) cout<<i<<' '<<find(1,dfn[i],dfn[i])<<endl;
//	cout<<"***"<<endl;
//	cout<<"***"<<endl;
//	for(int i=1;i<=n;i++) cout<<i<<' '<<dfn[i]<<endl;
//	cout<<"***"<<endl;

	while(m--) {
		int x,y;
	    cin>>x>>y;
	    cout<<lca(x,y)<<endl;
	}
	return 0;
}

老板代码:

#include<cstdio>
#include<vector>
using namespace std;
int n,m,s;
vector<vector<int>> a;
int fa[500005],size[500005],deep[500005];
int son[500005];
int dfs1(int u) {
	if(size[u]) return size[u];
	size[u]=1;
	deep[u]=deep[fa[u]]+1;
	for(auto&v:a[u]) {
		if(v==fa[u]) continue;
		fa[v]=u;
		dfs1(v);
		size[u]+=size[v];
		if(size[son[u]]<size[v])
			son[u]=v;
	}
	return size[u];
}
int top[500005];
void dfs2(int u) {
	if(son[fa[u]]==u) 
		top[u]=top[fa[u]];
	else 
		top[u]=u;
	for(auto&v:a[u])
		if(v^fa[u])
			dfs2(v);
}
int lca(int x,int y) {
	while(top[x]^top[y]) {
		if(deep[top[x]]<deep[top[y]]) swap(x,y);
		x=fa[top[x]];
	}
	return deep[x]<deep[y]?x:y;
}
int main() {
	scanf("%d%d%d",&n,&m,&s);
	for(int i=0;i<=n;i++) a.push_back({});
	for(int i=1;i<n;i++) {
		int u,v;
		scanf("%d%d",&u,&v);
		a[u].push_back(v);
		a[v].push_back(u);
	}
	dfs1(s);
	dfs2(s);
	while(m--) {
		int x,y;
		scanf("%d%d",&x,&y);
		printf("%d\n",lca(x,y));
	}
	return 0;
}

老板代码的主要细节:

  1. 正常来说dfs2需要先走重儿子,但是这道题不需要线段树维护,因此不需要重链编号连续,因此不需要这样设计。
  2. 注意lca函数中while循环的条件,top如果相同即跳过
  3. 注意lca函数中,deep[top[x]]<deep[top[y]],才交换,而不是deep[x]<deep[y]
  4. 注意lca函数中,x=fa[top[x]],而不是top[x]

树链剖分维护树上修改与查询

利用线段树可以维护一个连续区间,因此我们给节点重新编号,使得重链上的节点连续,便于维护。

新版代码:

#include<iostream>
#include<vector>
#include<cmath>
#include<algorithm>
using namespace std;
long long M;
const int N=1e5;
vector<vector<int>> a;
int h[N+5];
int root;
int fa[N+5],dep[N+5],siz[N+5],dfn[N+5],son[N+5],top[N+5],cnt;
int dfs1(int u) {
	siz[u]=1;
	dep[u]=dep[fa[u]]+1;
	for(auto&v:a[u])
		if(v^fa[u]) {
			fa[v]=u;
			if(siz[son[u]]<dfs1(v)) son[u]=v;
			siz[u]+=siz[v];
		}
	return siz[u];
}
void dfs2(int u,int t) {
	dfn[u]=++cnt;
	top[u]=t;
	if(son[u]) dfs2(son[u],t);
	for(auto&v:a[u])
		if(v^fa[u]&&v^son[u])
			dfs2(v,v);
}
struct tree {
	int l,r;
	long long sum,add;
} t[N<<2];
void build(int u,int l,int r) {
	t[u]= {l,r,0,0};
	if(l==r) return ;
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
}
void push_up(int u) {
	(t[u].sum=t[u<<1].sum+t[u<<1|1].sum)%=M;
}
void push_down(int u) {
	int l=u<<1,r=u<<1|1;
	(t[l].sum+=t[u].add*(t[l].r-t[l].l+1)%M)%=M;
	(t[r].sum+=t[u].add*(t[r].r-t[r].l+1)%M)%=M;
	(t[l].add+=t[u].add)%=M;
	(t[r].add+=t[u].add)%=M;
	t[u].add=0;
}
void push(int u,int l,int r,long long val) {
	if(l<=t[u].l&&t[u].r<=r) (t[u].sum+=val*(t[u].r-t[u].l+1)%M)%=M,(t[u].add+=val)%=M;
	else {
		push_down(u);
		int mid=t[u].l+t[u].r>>1;
		if(l<=mid) push(u<<1,l,r,val);
		if(mid<r) push(u<<1|1,l,r,val);
		push_up(u);
	}
}
long long find(int u,int l,int r) {
	if(l<=t[u].l&&t[u].r<=r) return t[u].sum;
	else {
		push_down(u);
		int mid=t[u].l+t[u].r>>1;
		long long ans=0;
		if(l<=mid) (ans+=find(u<<1,l,r))%=M;
		if(mid<r) (ans+=find(u<<1|1,l,r))%=M;
		push_up(u);
		return ans;
	}
}
//int lca(int x,int y) {
//	while(top[x]^top[y])
//		if(dep[top[x]]<dep[top[y]])
//			swap(x,y);
//		else
//			x=fa[top[x]];
//	if(dep[x]>dep[y]) swap(x,y);
//	return x;
//}
void pushx(int x,int y,long long val) {
	while(top[x]^top[y])
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		else
			push(1,dfn[top[x]],dfn[x],val),x=fa[top[x]];
	if(dep[x]>dep[y]) swap(x,y);
	push(1,dfn[x],dfn[y],val);
}
long long findx(int x,int y,long long sum=0) {
	while(top[x]^top[y])
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		else
			(sum+=find(1,dfn[top[x]],dfn[x]))%=M,x=fa[top[x]];
	if(dep[x]>dep[y]) swap(x,y);
	(sum+=find(1,dfn[x],dfn[y]))%=M;
	return sum;
}
int main() {
	int n,m;
	cin>>n>>m>>root>>M;
	a.resize(n+1);
	for(int i=1; i<=n; i++) cin>>h[i];
	for(int i=1,u,v; i<n; i++) {
		cin>>u>>v;
		a[u].push_back(v);
		a[v].push_back(u);
	}
	dfs1(root);
	dfs2(root,root);
	build(1,1,n);
	for(int i=1; i<=n; i++) push(1,dfn[i],dfn[i],h[i]);
//	cout<<"***"<<endl;
//	for(int i=1;i<=n;i++) cout<<i<<' '<<find(1,dfn[i],dfn[i])<<endl;
//	cout<<"***"<<endl;
//	cout<<"***"<<endl;
//	for(int i=1;i<=n;i++) cout<<i<<' '<<dfn[i]<<endl;
//	cout<<"***"<<endl;

	while(m--) {
		int op;
		cin>>op;
		switch(op) {
			case 1: {
				int x,y;
				long long val;
				cin>>x>>y>>val;
				pushx(x,y,val);
//				cout<<"***"<<endl;
//				for(int i=1; i<=n; i++)
//					cout<<i<<':'<<find(1,dfn[i],dfn[i])<<endl;
//				cout<<"***"<<endl;
				break;
			}
			case 2: {
				int x,y;
				cin>>x>>y;
				cout<<findx(x,y)<<endl;
				break;
			}
			case 3: {
				int x;
				long long val;
				cin>>x>>val;
				push(1,dfn[x],dfn[x]+siz[x]-1,val);
				break;
			}
			case 4: {
				int x;
				cin>>x;
				cout<<find(1,dfn[x],dfn[x]+siz[x]-1)<<endl;
				break;
			}
		}
	}
	return 0;
}

这里主要说一下怎么跳树链:

主要就是考虑到树剖求LCA的过程中,如果x,y不在同一条重链,我们每一步都跳x,y中重链顶端深度较深的点。

如果x,y在同一条重链,则x,y中深度较浅的结点就是原本x,y的LCA。

因此我们分两部分统计答案:

  1. 如果x,y不在同一条重链中,则我们对包含x的那一部分重链(从x~top[x])统计答案,然后跳。
  2. 如果x,y在一条重链上,则我们对x,y直接统计答案,然后返回。

此外,新版代码直接用dfn[x]+siz[x]-1来维护x子树中最后一个结点的dfn。而没有采用追溯值。

老板的代码:

#include<iostream>
#include<vector>
using namespace std;
int n,m,s,p;
vector<vector<int>> a;
typedef int intx[100005];
intx b,size,son,top,id,deep,fa,low;
//b:初始权值
//id:新编号
//low:追溯值,即本颗子树的最后一个节点的编号
int cnt;
void dfs1(int u) {
	size[u]=1;
	deep[u]=deep[fa[u]]+1;
	for(auto&v:a[u]) {
		if(v==fa[u]) continue;
		fa[v]=u;
		dfs1(v);
		size[u]+=size[v];
		if(size[son[u]]<size[v])
			son[u]=v;
	}
}
void dfs2(int u) {
	id[u]=++cnt;
	if(u==son[fa[u]]) top[u]=top[fa[u]];
	else top[u]=u;
	if(son[u]) dfs2(son[u]);
	for(auto&v:a[u])
		if(v^son[u]&&v^fa[u])
			dfs2(v);
	low[u]=cnt;
}
int lca(int x,int y) {
	while(top[x]^top[y])
		if(deep[top[x]]>deep[top[y]]) x=fa[top[x]];
		else y=fa[top[y]];
	return deep[x]<deep[y]?x:y;
}
struct node {
	long long l,r;
	long long add,sum;
} t[100005<<2];
void build(int u,int l,int r) {
	t[u]= {l,r,0,0};
	if(l==r) return ;
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
}
void push_up(int u) {
	t[u].sum=t[u<<1].sum+t[u<<1|1].sum;
}
void push_down(int u) {
	if(t[u].add) {
		int l=u<<1,r=u<<1|1;
		(t[l].sum+=(t[l].r-t[l].l+1)*t[u].add)%=p;
		(t[r].sum+=(t[r].r-t[r].l+1)*t[u].add)%=p;
		(t[l].add+=t[u].add)%=p;
		(t[r].add+=t[u].add)%=p;
		t[u].add=0;
	}
}
void push(int u,int l,int r,long long v) {
	if(l<=t[u].l&&t[u].r<=r) {
		(t[u].sum+=(t[u].r-t[u].l+1)*v)%=p; 
		(t[u].add+=v)%=p;
		return ;
	}
	push_down(u);
	int mid=t[u].l+t[u].r>>1;
	if(l<=mid) push(u<<1,l,r,v);
	if(mid<r) push(u<<1|1,l,r,v);
	push_up(u);
}
long long find(int u,int l,int r) {
	if(l<=t[u].l&&t[u].r<=r) return t[u].sum;
	push_down(u);
	int mid=t[u].l+t[u].r>>1;
	long long ans=0;
	if(l<=mid) (ans+=find(u<<1,l,r))%=p;
	if(mid<r) (ans+=find(u<<1|1,l,r))%=p;
	return ans;
}
void pushx(int x,int y,long long v) {
	while(top[x]^top[y]) {
		if(deep[top[x]]<deep[top[y]]) swap(x,y);
		push(1,id[top[x]],id[x],v);
		x=fa[top[x]];
	}
	if(id[x]>id[y]) swap(x,y);
	push(1,id[x],id[y],v);
}
long long findx(int x,int y) {
	long long ans=0;
	while(top[x]^top[y]) {
		if(deep[top[x]]<deep[top[y]]) swap(x,y);
		(ans+=find(1,id[top[x]],id[x]))%=p;
		x=fa[top[x]];
	}
	if(id[x]>id[y]) swap(x,y);
	(ans+=find(1,id[x],id[y]))%=p;
	return ans;
}
int main() {
	cin>>n>>m>>s>>p;
	build(1,1,n);
	for(int i=0; i<=n; i++) a.push_back({});
	for(int i=1; i<=n; i++) cin>>b[i];
	for(int i=1; i<n; i++) {
		int u,v;
		cin>>u>>v;
		a[u].push_back(v);
		a[v].push_back(u);
	}
	dfs1(s);
	dfs2(s);
	for(int i=1; i<=n; i++)
		push(1,id[i],id[i],b[i]);
//	for(int i=1;i<=n;i++)
//		cout<<i<<':'<<find(1,id[i],id[i])<<endl;
	while(m--) {
		int op;
		cin>>op;
		switch(op) {
			case 1: {
				int x,y;
				long long z;
				cin>>x>>y>>z;
				pushx(x,y,z);
//				cout<<"*"<<endl;
//				for(int i=1; i<=n; i++)
//					cout<<i<<':'<<find(1,id[i],id[i])<<endl;
				break;
			}
			case 2: {
				int x,y;
				cin>>x>>y;
				cout<<findx(x,y)<<endl;
				break;
			}
			case 3: {
				int x;
				long long z;
				cin>>x>>z;
				push(1,id[x],low[x],z);
//				cout<<"*"<<endl;
//				for(int i=1; i<=n; i++)
//					cout<<i<<':'<<find(1,id[i],id[i])<<endl;
				break;
			}
			case 4: {
				int x;
				cin>>x;
				cout<<find(1,id[x],low[x])<<endl;
				break;
			}
		}
	}
	return 0;
}
  1. 注意findx和pushx函数在while循环外面还要再执行一次
  2. 注意传入find和push里面的参数是id[x],而不是x
  3. 一条重链上的节点编号是连续的,可以快速维护,一颗子树上的节点编号也是连续的,也可以快速维护。

后记

于是皆大欢喜

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值