树链剖分(模板)

一个个人认为讲的比较好的B站up主:【AgOHの算法胡扯】树链剖分

前置知识:图的存储及遍历,dfs序(只存第一次被访问),线段树区间操作。

树链剖分支持的操作:(点权以及边权是可以互相转化的,这里只说点权。)

upd:2020/1/15最下方新增边剖示例。

  1. 修改及查询以x为根的子树权值
  2. 修改及查询一条链上的权值(链的定义为:树上两点的最短路径)
int index1;//时间戳;

int size1[maxn];//子树大小;
bool vis[maxn];
int depth[maxn];//深度;
int son[maxn];//重儿子;
int dfn[maxn];//dfs序;
int top1[maxn];//链的起点;
int val[maxn],W[maxn];//节点的值,节点代表的dfn对应的值;
int fa[maxn];//节点x的直接父亲;

void dfs1(int u);//预处理depth[],size1[],fa[],son[];

void dfs2(int u,int topf);//预处理dfs序,W[],top1[];

void build();//线段树建树;

void updata()//线段树的区间修改;

int myfind();//线段树的区间查询;

void updatatree(int u,int val);//给以u为根的子树加上val;

void updatachain(int u,int v,int val);//给u与v链上的所有点加上val;

int myfind(int u);//查询以u为根的子树的权值和;

int myfindchain(int u,int v);//查询u与v链上的节点权值和;

首先我们先选取一个根节点进行一次dfs1(),将一些东西预处理出来,然后再进行一次dfs2(),然后建线段树。

针对一类操作来说,以某个节点为根的子树在dfs的过程中一定是连续的,我们又处理出了子树的大小,那么就可以转化为线段树的区间操作了:

void updatatree(int x,int val){
    val%=mod;
    updata(1,index1,1,dfn[x],dfn[x]+size1[x]-1,val);
}

ll myfindtree(int x){
    return myfind(1,index1,1,dfn[x],dfn[x]+size1[x]-1);
}

然后是二类操作,首先一个引理:除根节点外,任何一个节点的父亲节点一定在一条重链(一堆连续的重儿子与其最上方的父亲组成的链)上,那么是不是可以经过若干次跳跃操作,将两个节点转移到同一条重链上,同时跳跃的时候需要注意不要跳过头了:

void updatachain(int x,int y,int val){
    val%=mod;
    while(top1[x]!=top1[y]){
        if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
        //如果上面那个判断的是x与y之间的深度的话,就可能会跳过头;
        updata(1,index1,1,dfn[top1[x]],dfn[x],val);
        x=fa[top1[x]];
    }
    if(depth[x]>depth[y]) swap(x,y);
    updata(1,index1,1,dfn[x],dfn[y],val);
}

ll myfindchain(int x,int y){
    ll res=0;
    while(top1[x]!=top1[y]){
        if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
        res+=myfind(1,index1,1,dfn[top1[x]],dfn[x]);
        res%=mod;
        x=fa[top1[x]];
    }
    if(depth[x]>depth[y]) swap(x,y);
    res+=myfind(1,index1,1,dfn[x],dfn[y]);
    res%=mod;
    return res;
}

 

总:

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;

const int maxn=1e5+7;
int mod;

int index1;//时间戳;

int size1[maxn];//子树大小;
bool vis[maxn];
int depth[maxn];//深度;
int son[maxn];//重儿子;
int dfn[maxn];//dfs序;
int top1[maxn];//链的起点;
int val[maxn],W[maxn];//节点的值,节点代表的dfn对应的值;
int fa[maxn];//节点x的直接父亲;

struct Edge{
    int v,w,next;
}edge[maxn<<1];

int head[maxn],top;
void init(){
    top=0;
    memset(head,-1,sizeof(head));
}

void add(int u,int v,int w){
    edge[top].v=v;
    edge[top].w=w;
    edge[top].next=head[u];
    head[u]=top++;
}

void dfs1(int u){
    vis[u]=1;
    size1[u]=1;
    int maxx=-1;
    int v;
    for(int i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        if(vis[v]) continue;
        fa[v]=u;
        depth[v]=depth[u]+1;
        dfs1(v);
        size1[u]+=size1[v];
        if(size1[v]>maxx){
            maxx=size1[v];
            son[u]=v;
        }
    }
}

void dfs2(int u,int t){
    dfn[u]=++index1;
    W[index1]=val[u];
    vis[u]=1;
    top1[u]=t;
    if(!son[u]) return ;
    dfs2(son[u],t);
    int v;
    for(int i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        if(vis[v]||v==son[u]) continue;
        dfs2(v,v);
    }
}

ll sum[maxn<<2|1];
ll lazy[maxn<<2|1];

void pushup(int k){
    sum[k]=sum[k<<1]+sum[k<<1|1];
    sum[k]%=mod;
}

void pushdown(int l,int r,int k){
    if(lazy[k]){
        int mid=(l+r)>>1;
        lazy[k<<1]+=lazy[k];
        lazy[k<<1|1]+=lazy[k];
        sum[k<<1]+=(mid-l+1)*lazy[k];
        sum[k<<1|1]+=(r-mid)*lazy[k];
        lazy[k<<1]%=mod;
        lazy[k<<1|1]%=mod;
        sum[k<<1]%=mod;
        sum[k<<1|1]%=mod;
        lazy[k]=0;
    }
}

void build(int l,int r,int k){
    lazy[k]=0;
    if(l==r){
        sum[k]=W[l];
        sum[k]%=mod;
        return ;
    }
    int mid=(l+r)>>1;
    build(l,mid,k<<1);
    build(mid+1,r,k<<1|1);
    pushup(k);
}



void updata(int l,int r,int k,int L,int R,int val){
    if(l>=L&&r<=R){
        sum[k]+=(r-l+1)*val;
        lazy[k]+=val;
        sum[k]%=mod;
        lazy[k]%=mod;
        return ;
    }
    pushdown(l,r,k);
    int mid=(l+r)>>1;
    if(L<=mid) updata(l,mid,k<<1,L,R,val);
    if(R>mid) updata(mid+1,r,k<<1|1,L,R,val);
    pushup(k);
}

ll myfind(int l,int r,int k,int L,int R){
    if(l>=L&&r<=R) return sum[k];
    pushdown(l,r,k);
    int mid=(l+r)>>1;
    ll res=0;
    if(L<=mid) res+=myfind(l,mid,k<<1,L,R);
    if(R>mid) res+=myfind(mid+1,r,k<<1|1,L,R);
    pushup(k);
    return res%mod;
}

void updatatree(int x,int val){
    val%=mod;
    updata(1,index1,1,dfn[x],dfn[x]+size1[x]-1,val);
}

ll myfindtree(int x){
    return myfind(1,index1,1,dfn[x],dfn[x]+size1[x]-1);
}

void updatachain(int x,int y,int val){
    val%=mod;
    while(top1[x]!=top1[y]){
        if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
        updata(1,index1,1,dfn[top1[x]],dfn[x],val);
        x=fa[top1[x]];
    }
    if(depth[x]>depth[y]) swap(x,y);
    updata(1,index1,1,dfn[x],dfn[y],val);
}

ll myfindchain(int x,int y){
    ll res=0;
    while(top1[x]!=top1[y]){
        if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
        res+=myfind(1,index1,1,dfn[top1[x]],dfn[x]);
        res%=mod;
        x=fa[top1[x]];
    }
    if(depth[x]>depth[y]) swap(x,y);
    res+=myfind(1,index1,1,dfn[x],dfn[y]);
    res%=mod;
    return res;
}

int main(){
    int n,q,root;
    scanf("%d%d%d%d",&n,&q,&root,&mod);
    init();
    for(int i=1;i<=n;++i) scanf("%d",&val[i]);
    int u,v,w;
    for(int i=1;i<n;++i){
        scanf("%d%d",&u,&v);
        add(u,v,1);
        add(v,u,1);
    }
    depth[root]=1;
    dfs1(root);
    memset(vis,0,sizeof(vis));
    dfs2(root,root);
    build(1,index1,1);
    //for(int i=1;i<=n;++i) cout<<top1[i]<<endl;
    int id;
    while(q--){
        scanf("%d",&id);
        if(id==1){
            scanf("%d%d%d",&u,&v,&w);
            updatachain(u,v,w);
        }
        else if(id==2){
            scanf("%d%d",&u,&v);
            printf("%lld\n",myfindchain(u,v));
        }
        else if(id==3){
            scanf("%d%d",&u,&w);
            updatatree(u,w);
        }
        else if(id==4){
            scanf("%d",&u);
            printf("%lld\n",myfindtree(u));
        }
    }
    return 0;
}

复杂度:O(n*log_2(n)*log_2(n))

 

还有一道练习题,是链上每个数字开方,写过线段树的都会,就不再赘述了。

HDU-6547 Tree(注:该题中这个算法的复杂度达到了n*log_2(n)^3,其实是不可通过的,正解是括号序加线段树,只是出题人没有卡掉这种超时的做法)

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;

const int maxn=1e5+7;

ll sum[maxn<<2|1];

bool f[maxn<<2|1];

int W[maxn];

void pushup(int k){
    sum[k]=sum[k<<1]+sum[k<<1|1];
    f[k]=f[k<<1]&f[k<<1|1];
}

void build(int l,int r,int k){
    if(l==r){
        sum[k]=W[l];
        if(sum[k]<=1) f[k]=1;
        return ;
    }
    int mid=(l+r)>>1;
    build(l,mid,k<<1);
    build(mid+1,r,k<<1|1);
    pushup(k);
}

void updata(int l,int r,int k,int L,int R){
    if(l==r){
        sum[k]=sqrt(sum[k]);
        if(sum[k]<=1) f[k]=1;
        return ;
    }
    if(l>=L&&r<=R){
        if(f[k]) return ;
    }
    int mid=(l+r)>>1;
    if(L<=mid) updata(l,mid,k<<1,L,R);
    if(R>mid) updata(mid+1,r,k<<1|1,L,R);
    pushup(k);
}

ll myfind(int l,int r,int k,int L,int R){
    if(l>=L&&r<=R) return sum[k];
    int mid=(l+r)>>1;
    ll res=0;
    if(L<=mid) res+=myfind(l,mid,k<<1,L,R);
    if(R>mid) res+=myfind(mid+1,r,k<<1|1,L,R);
    return res;
}

struct Edge{
    int v,next;
}edge[maxn<<1];
int head[maxn],top;
void init(){
    top=0;
    memset(head,-1,sizeof(head));
}


int num;
int son[maxn];
int size1[maxn];
int dfn[maxn];
int top1[maxn];
int fa[maxn];
int depth[maxn];
int val[maxn];

void dfs1(int u){
    size1[u]=1;
    int maxx=0;
    int v;
    for(int i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        if(depth[v]) continue;
        depth[v]=depth[u]+1;
        fa[v]=u;
        dfs1(v);
        size1[u]+=size1[v];
        if(size1[v]>maxx){
            maxx=size1[v];
            son[u]=v;
        }
    }
}

void dfs2(int u,int t){
    top1[u]=t;
    dfn[u]=++num;
    W[num]=val[u];
    if(!son[u]) return ;
    dfs2(son[u],t);
    int v;
    for(int i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        if(dfn[v]||v==son[u]) continue;
        dfs2(v,v);
    }
}

void updatachain(int x,int y){
    while(top1[x]!=top1[y]){
        if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
        updata(1,num,1,dfn[top1[x]],dfn[x]);
        x=fa[top1[x]];
    }
    if(depth[x]>depth[y]) swap(x,y);
    updata(1,num,1,dfn[x],dfn[y]);
}

ll myfindchain(int x,int y){
    ll res=0;
    while(top1[x]!=top1[y]){
        if(depth[top1[x]]<depth[top1[y]]) swap(x,y);
        res+=myfind(1,num,1,dfn[top1[x]],dfn[x]);
        x=fa[top1[x]];
    }
    if(depth[x]>depth[y]) swap(x,y);
    res+=myfind(1,num,1,dfn[x],dfn[y]);
    return res;
}

void add(int u,int v){
    edge[top].v=v;
    edge[top].next=head[u];
    head[u]=top++;
}

int main(){
    int n,q;
    int op,x,y;
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;++i) scanf("%d",&val[i]);
    depth[1]=1;
    init();
    for(int i=1;i<n;++i){
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs1(1);
    dfs2(1,1);
    build(1,num,1);
    while(q--){
        scanf("%d%d%d",&op,&x,&y);
        if(op) printf("%lld\n",myfindchain(x,y));
        else updatachain(x,y);
    }
    return 0;
}

边剖 

与点剖不同的是处于同一条重链时候的操作,
以及预处理多一个儿子与父亲的直连边的权值。
用一条边的后向点的dfs序来代表这条边的编号。
如果要修改一条边的边权的话,应该修改的是边的两端点中深度较深的点的编号。
int dep[maxn],val[maxn],dfn[maxn],num,topp[maxn],
size1[maxn],son[maxn],fa[maxn],faedge[maxn];//儿子与父亲直连边权;
bool vis[maxn];
void dfs1(int u){
    size1[u]=1;
    vis[u]=1;
    int v;
    int maxx=-1;
    for(int i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        if(vis[v]) continue;
        dep[v]=dep[u]+1;
        fa[v]=u;
        faedge[v]=edge[i].w;
        dfs1(v);
        size1[u]+=size1[v];
        if(size1[v]>maxx){
            son[u]=v;
            maxx=size1[v];
        }
    }
}

void dfs2(int u,int t,int lastw){
    topp[u]=t;
    dfn[u]=++num;
    val[num]=lastw;
    if(!son[u]) return ;
    dfs2(son[u],t,faedge[son[u]]);
    int v,w;
    for(int i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        w=edge[i].w;
        if(dfn[v]||v==son[u]) continue;
        dfs2(v,v,w);
    }
}

int myfindchain(int x,int y){
    int res=0;
    while(topp[x]!=topp[y]){
        if(dep[topp[x]]<dep[topp[y]]) swap(x,y);
        res+=myfindsum(1,n,1,dfn[topp[x]],dfn[x]);
        x=fa[topp[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    res+=myfindsum(1,n,1,dfn[x]+1,dfn[y]);
    return res;
}

void updatachain(int x,int y){
    while(topp[x]!=topp[y]){
        if(dep[topp[x]]<dep[topp[y]]) swap(x,y);
        updata(1,n,1,dfn[topp[x]],dfn[x]);
        x=fa[topp[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    updata(1,n,1,dfn[x]+1,dfn[y]);
}

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值