一个个人认为讲的比较好的B站up主:【AgOHの算法胡扯】树链剖分
前置知识:图的存储及遍历,dfs序(只存第一次被访问),线段树区间操作。
树链剖分支持的操作:(点权以及边权是可以互相转化的,这里只说点权。)
upd:2020/1/15最下方新增边剖示例。
- 修改及查询以x为根的子树权值
- 修改及查询一条链上的权值(链的定义为:树上两点的最短路径)
int index1;//时间戳;
int size1[maxn];//子树大小;
bool vis[maxn];
int depth[maxn];//深度;
int son[maxn];//重儿子;
int dfn[maxn];//dfs序;
int top1[maxn];//链的起点;
int val[maxn],W[maxn];//节点的值,节点代表的dfn对应的值;
int fa[maxn];//节点x的直接父亲;
void dfs1(int u);//预处理depth[],size1[],fa[],son[];
void dfs2(int u,int topf);//预处理dfs序,W[],top1[];
void build();//线段树建树;
void updata()//线段树的区间修改;
int myfind();//线段树的区间查询;
void updatatree(int u,int val);//给以u为根的子树加上val;
void updatachain(int u,int v,int val);//给u与v链上的所有点加上val;
int myfind(int u);//查询以u为根的子树的权值和;
int myfindchain(int u,int v);//查询u与v链上的节点权值和;
首先我们先选取一个根节点进行一次dfs1(),将一些东西预处理出来,然后再进行一次dfs2(),然后建线段树。
针对一类操作来说,以某个节点为根的子树在dfs的过程中一定是连续的,我们又处理出了子树的大小,那么就可以转化为线段树的区间操作了:
void updatatree(int x,int val){
val%=mod;
updata(1,index1,1,dfn[x],dfn[x]+size1[x]-1,val);
}
ll myfindtree(int x){
return myfind(1,index1,1,dfn[x],dfn[x]+size1[x]-1);
}
然后是二类操作,首先一个引理:除根节点外,任何一个节点的父亲节点一定在一条重链(一堆连续的重儿子与其最上方的父亲组成的链)上,那么是不是可以经过若干次跳跃操作,将两个节点转移到同一条重链上,同时跳跃的时候需要注意不要跳过头了:
void updatachain(int x,int y,int val){
val%=mod;
while(top1[x]!=top1[y]){
if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
//如果上面那个判断的是x与y之间的深度的话,就可能会跳过头;
updata(1,index1,1,dfn[top1[x]],dfn[x],val);
x=fa[top1[x]];
}
if(depth[x]>depth[y]) swap(x,y);
updata(1,index1,1,dfn[x],dfn[y],val);
}
ll myfindchain(int x,int y){
ll res=0;
while(top1[x]!=top1[y]){
if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
res+=myfind(1,index1,1,dfn[top1[x]],dfn[x]);
res%=mod;
x=fa[top1[x]];
}
if(depth[x]>depth[y]) swap(x,y);
res+=myfind(1,index1,1,dfn[x],dfn[y]);
res%=mod;
return res;
}
总:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+7;
int mod;
int index1;//时间戳;
int size1[maxn];//子树大小;
bool vis[maxn];
int depth[maxn];//深度;
int son[maxn];//重儿子;
int dfn[maxn];//dfs序;
int top1[maxn];//链的起点;
int val[maxn],W[maxn];//节点的值,节点代表的dfn对应的值;
int fa[maxn];//节点x的直接父亲;
struct Edge{
int v,w,next;
}edge[maxn<<1];
int head[maxn],top;
void init(){
top=0;
memset(head,-1,sizeof(head));
}
void add(int u,int v,int w){
edge[top].v=v;
edge[top].w=w;
edge[top].next=head[u];
head[u]=top++;
}
void dfs1(int u){
vis[u]=1;
size1[u]=1;
int maxx=-1;
int v;
for(int i=head[u];i!=-1;i=edge[i].next){
v=edge[i].v;
if(vis[v]) continue;
fa[v]=u;
depth[v]=depth[u]+1;
dfs1(v);
size1[u]+=size1[v];
if(size1[v]>maxx){
maxx=size1[v];
son[u]=v;
}
}
}
void dfs2(int u,int t){
dfn[u]=++index1;
W[index1]=val[u];
vis[u]=1;
top1[u]=t;
if(!son[u]) return ;
dfs2(son[u],t);
int v;
for(int i=head[u];i!=-1;i=edge[i].next){
v=edge[i].v;
if(vis[v]||v==son[u]) continue;
dfs2(v,v);
}
}
ll sum[maxn<<2|1];
ll lazy[maxn<<2|1];
void pushup(int k){
sum[k]=sum[k<<1]+sum[k<<1|1];
sum[k]%=mod;
}
void pushdown(int l,int r,int k){
if(lazy[k]){
int mid=(l+r)>>1;
lazy[k<<1]+=lazy[k];
lazy[k<<1|1]+=lazy[k];
sum[k<<1]+=(mid-l+1)*lazy[k];
sum[k<<1|1]+=(r-mid)*lazy[k];
lazy[k<<1]%=mod;
lazy[k<<1|1]%=mod;
sum[k<<1]%=mod;
sum[k<<1|1]%=mod;
lazy[k]=0;
}
}
void build(int l,int r,int k){
lazy[k]=0;
if(l==r){
sum[k]=W[l];
sum[k]%=mod;
return ;
}
int mid=(l+r)>>1;
build(l,mid,k<<1);
build(mid+1,r,k<<1|1);
pushup(k);
}
void updata(int l,int r,int k,int L,int R,int val){
if(l>=L&&r<=R){
sum[k]+=(r-l+1)*val;
lazy[k]+=val;
sum[k]%=mod;
lazy[k]%=mod;
return ;
}
pushdown(l,r,k);
int mid=(l+r)>>1;
if(L<=mid) updata(l,mid,k<<1,L,R,val);
if(R>mid) updata(mid+1,r,k<<1|1,L,R,val);
pushup(k);
}
ll myfind(int l,int r,int k,int L,int R){
if(l>=L&&r<=R) return sum[k];
pushdown(l,r,k);
int mid=(l+r)>>1;
ll res=0;
if(L<=mid) res+=myfind(l,mid,k<<1,L,R);
if(R>mid) res+=myfind(mid+1,r,k<<1|1,L,R);
pushup(k);
return res%mod;
}
void updatatree(int x,int val){
val%=mod;
updata(1,index1,1,dfn[x],dfn[x]+size1[x]-1,val);
}
ll myfindtree(int x){
return myfind(1,index1,1,dfn[x],dfn[x]+size1[x]-1);
}
void updatachain(int x,int y,int val){
val%=mod;
while(top1[x]!=top1[y]){
if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
updata(1,index1,1,dfn[top1[x]],dfn[x],val);
x=fa[top1[x]];
}
if(depth[x]>depth[y]) swap(x,y);
updata(1,index1,1,dfn[x],dfn[y],val);
}
ll myfindchain(int x,int y){
ll res=0;
while(top1[x]!=top1[y]){
if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
res+=myfind(1,index1,1,dfn[top1[x]],dfn[x]);
res%=mod;
x=fa[top1[x]];
}
if(depth[x]>depth[y]) swap(x,y);
res+=myfind(1,index1,1,dfn[x],dfn[y]);
res%=mod;
return res;
}
int main(){
int n,q,root;
scanf("%d%d%d%d",&n,&q,&root,&mod);
init();
for(int i=1;i<=n;++i) scanf("%d",&val[i]);
int u,v,w;
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v,1);
add(v,u,1);
}
depth[root]=1;
dfs1(root);
memset(vis,0,sizeof(vis));
dfs2(root,root);
build(1,index1,1);
//for(int i=1;i<=n;++i) cout<<top1[i]<<endl;
int id;
while(q--){
scanf("%d",&id);
if(id==1){
scanf("%d%d%d",&u,&v,&w);
updatachain(u,v,w);
}
else if(id==2){
scanf("%d%d",&u,&v);
printf("%lld\n",myfindchain(u,v));
}
else if(id==3){
scanf("%d%d",&u,&w);
updatatree(u,w);
}
else if(id==4){
scanf("%d",&u);
printf("%lld\n",myfindtree(u));
}
}
return 0;
}
复杂度:
还有一道练习题,是链上每个数字开方,写过线段树的都会,就不再赘述了。
HDU-6547 Tree(注:该题中这个算法的复杂度达到了,其实是不可通过的,正解是括号序加线段树,只是出题人没有卡掉这种超时的做法)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+7;
ll sum[maxn<<2|1];
bool f[maxn<<2|1];
int W[maxn];
void pushup(int k){
sum[k]=sum[k<<1]+sum[k<<1|1];
f[k]=f[k<<1]&f[k<<1|1];
}
void build(int l,int r,int k){
if(l==r){
sum[k]=W[l];
if(sum[k]<=1) f[k]=1;
return ;
}
int mid=(l+r)>>1;
build(l,mid,k<<1);
build(mid+1,r,k<<1|1);
pushup(k);
}
void updata(int l,int r,int k,int L,int R){
if(l==r){
sum[k]=sqrt(sum[k]);
if(sum[k]<=1) f[k]=1;
return ;
}
if(l>=L&&r<=R){
if(f[k]) return ;
}
int mid=(l+r)>>1;
if(L<=mid) updata(l,mid,k<<1,L,R);
if(R>mid) updata(mid+1,r,k<<1|1,L,R);
pushup(k);
}
ll myfind(int l,int r,int k,int L,int R){
if(l>=L&&r<=R) return sum[k];
int mid=(l+r)>>1;
ll res=0;
if(L<=mid) res+=myfind(l,mid,k<<1,L,R);
if(R>mid) res+=myfind(mid+1,r,k<<1|1,L,R);
return res;
}
struct Edge{
int v,next;
}edge[maxn<<1];
int head[maxn],top;
void init(){
top=0;
memset(head,-1,sizeof(head));
}
int num;
int son[maxn];
int size1[maxn];
int dfn[maxn];
int top1[maxn];
int fa[maxn];
int depth[maxn];
int val[maxn];
void dfs1(int u){
size1[u]=1;
int maxx=0;
int v;
for(int i=head[u];i!=-1;i=edge[i].next){
v=edge[i].v;
if(depth[v]) continue;
depth[v]=depth[u]+1;
fa[v]=u;
dfs1(v);
size1[u]+=size1[v];
if(size1[v]>maxx){
maxx=size1[v];
son[u]=v;
}
}
}
void dfs2(int u,int t){
top1[u]=t;
dfn[u]=++num;
W[num]=val[u];
if(!son[u]) return ;
dfs2(son[u],t);
int v;
for(int i=head[u];i!=-1;i=edge[i].next){
v=edge[i].v;
if(dfn[v]||v==son[u]) continue;
dfs2(v,v);
}
}
void updatachain(int x,int y){
while(top1[x]!=top1[y]){
if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
updata(1,num,1,dfn[top1[x]],dfn[x]);
x=fa[top1[x]];
}
if(depth[x]>depth[y]) swap(x,y);
updata(1,num,1,dfn[x],dfn[y]);
}
ll myfindchain(int x,int y){
ll res=0;
while(top1[x]!=top1[y]){
if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
res+=myfind(1,num,1,dfn[top1[x]],dfn[x]);
x=fa[top1[x]];
}
if(depth[x]>depth[y]) swap(x,y);
res+=myfind(1,num,1,dfn[x],dfn[y]);
return res;
}
void add(int u,int v){
edge[top].v=v;
edge[top].next=head[u];
head[u]=top++;
}
int main(){
int n,q;
int op,x,y;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;++i) scanf("%d",&val[i]);
depth[1]=1;
init();
for(int i=1;i<n;++i){
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(1);
dfs2(1,1);
build(1,num,1);
while(q--){
scanf("%d%d%d",&op,&x,&y);
if(op) printf("%lld\n",myfindchain(x,y));
else updatachain(x,y);
}
return 0;
}
边剖
与点剖不同的是处于同一条重链时候的操作,
以及预处理多一个儿子与父亲的直连边的权值。
用一条边的后向点的dfs序来代表这条边的编号。
如果要修改一条边的边权的话,应该修改的是边的两端点中深度较深的点的编号。
int dep[maxn],val[maxn],dfn[maxn],num,topp[maxn],
size1[maxn],son[maxn],fa[maxn],faedge[maxn];//儿子与父亲直连边权;
bool vis[maxn];
void dfs1(int u){
size1[u]=1;
vis[u]=1;
int v;
int maxx=-1;
for(int i=head[u];i!=-1;i=edge[i].next){
v=edge[i].v;
if(vis[v]) continue;
dep[v]=dep[u]+1;
fa[v]=u;
faedge[v]=edge[i].w;
dfs1(v);
size1[u]+=size1[v];
if(size1[v]>maxx){
son[u]=v;
maxx=size1[v];
}
}
}
void dfs2(int u,int t,int lastw){
topp[u]=t;
dfn[u]=++num;
val[num]=lastw;
if(!son[u]) return ;
dfs2(son[u],t,faedge[son[u]]);
int v,w;
for(int i=head[u];i!=-1;i=edge[i].next){
v=edge[i].v;
w=edge[i].w;
if(dfn[v]||v==son[u]) continue;
dfs2(v,v,w);
}
}
int myfindchain(int x,int y){
int res=0;
while(topp[x]!=topp[y]){
if(dep[topp[x]]<dep[topp[y]]) swap(x,y);
res+=myfindsum(1,n,1,dfn[topp[x]],dfn[x]);
x=fa[topp[x]];
}
if(dep[x]>dep[y]) swap(x,y);
res+=myfindsum(1,n,1,dfn[x]+1,dfn[y]);
return res;
}
void updatachain(int x,int y){
while(topp[x]!=topp[y]){
if(dep[topp[x]]<dep[topp[y]]) swap(x,y);
updata(1,n,1,dfn[topp[x]],dfn[x]);
x=fa[topp[x]];
}
if(dep[x]>dep[y]) swap(x,y);
updata(1,n,1,dfn[x]+1,dfn[y]);
}