树链剖分(线段树的应用)

【题目描述】

原题来自:ZJOI 2008

一树上有 n 个节点,编号分别为 1 到 n,每个节点都有一个权值 w。我们将以下面的形式来要求你对这棵树完成一些操作:

1.CHANGE u t :把节点 u 权值改为 t;

2.QMAX u v :询问点 u 到点 v 路径上的节点的最大权值;

3.QSUM u v :询问点 u 到点 v 路径上的节点的权值和。

注意:从点 u 到点 v 路径上的节点包括 u 和 v 本身。

【输入】

第一行为一个数 n,表示节点个数;

接下来 n−1 行,每行两个整数 a,b,表示节点 a 与节点 b 之间有一条边相连;

接下来 n 行,每行一个整数,第 i 行的整数 wi 表示节点 i 的权值;

接下来一行,为一个整数 q ,表示操作总数;

接下来 q 行,每行一个操作,以 CHANGE u t 或 QMAX u v 或 QSUM u v的形式给出。

【输出】

对于每个 QMAX 或 QSUM 的操作,每行输出一个整数表示要求的结果。

【输入样例】
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
【输出样例】
4
1
2
2
10
6
5
6
5
16
【提示】

数据范围与提示:

对于 100% 的数据,有 1≤n≤3×104,0≤q≤2×105 。中途操作中保证每个节点的权值 w 在 −30000 至 30000 之间。

#include<bits/stdc++.h>//xyc大佬讲解的程序
#define lc (k<<1)
#define rc (k<<1|1)
#define mid ((a[k].l+a[k].r)>>1)
using namespace std;
inline void in(int &x){
	int f=1;x=0;char w=getchar();
	while(w<'0'||w>'9'){if(w=='-') f=-f;w=getchar();}
	while(w>='0'&&w<='9'){x=(x<<3)+(x<<1)+(w^48);w=getchar();}
	x*=f;
}
const int N=3e4+10;
struct node{
	int son,fa,deep,size,top,zhi,id;
}p[N];
struct tree{
	int l,r,val,maxn;
}a[N<<2];
int n,m,x,y,tot,cnt;char op[100];
int fir[N],vis[N],rk[N],nxt[N<<1],ver[N<<1];
inline void add(int x,int y){ver[++tot]=y,nxt[tot]=fir[x],fir[x]=tot;}
void dfs1(int x){
	vis[x]=1,p[x].size=1,p[x].deep=p[p[x].fa].deep+1;int maxn=0;
	for(int i=fir[x];i;i=nxt[i]){
		int y=ver[i];if(vis[y]) continue;
		p[y].fa=x,dfs1(y);p[x].size+=p[y].size;
		if(p[y].size>maxn) maxn=p[y].size,p[x].son=y;
	}
}
void dfs2(int x){
	p[x].id=++cnt,rk[cnt]=x,p[x].top=x==p[p[x].fa].son?p[p[x].fa].top:x;
	if(!p[x].son) return ;dfs2(p[x].son);
	for(int i=fir[x];i;i=nxt[i]) {int y=ver[i];if(!p[y].id) dfs2(y);}
}
void build(int k,int l,int r){
	a[k].l=l,a[k].r=r;
	if(l==r) {a[k].maxn=a[k].val=p[rk[l]].zhi;return;}
	build(lc,l,mid),build(rc,mid+1,r);
	a[k].maxn=max(a[lc].maxn,a[rc].maxn);
	a[k].val=a[lc].val+a[rc].val;
}
void add(int k,int x,int z){
	if(a[k].l==a[k].r){a[k].maxn=a[k].val=z;return;}
	if(x<=mid) add(lc,x,z);else add(rc,x,z);
	a[k].maxn=max(a[lc].maxn,a[rc].maxn);
	a[k].val=a[lc].val+a[rc].val;
}
int asksum(int k,int l,int r){
	if(a[k].l>=l&&a[k].r<=r)return a[k].val;
	int ans=0;
	if(l<=mid) ans+=asksum(lc,l,min(r,mid));
	if(r>mid) ans+=asksum(rc,max(mid+1,l),r);
	return ans;
}
int askmax(int k,int l,int r){
	if(a[k].l>=l&&a[k].r<=r)return a[k].maxn;
	int ans=-N;
	if(l<=mid) ans=max(ans,askmax(lc,l,min(mid,r)));
	if(r>mid) ans=max(ans,askmax(rc,max(mid+1,l),r));
	return ans;
}
int main(){
	in(n);for(int i=1;i<n;i++) in(x),in(y),add(x,y),add(y,x);
	for(int i=1;i<=n;i++) in(p[i].zhi);
	dfs1(1),dfs2(1),build(1,1,n),in(m);
	for(int i=1;i<=m;i++){
		scanf("%s",op),in(x),in(y);
		if(op[3]=='X'){
			int ans=-N;
			while(p[x].top!=p[y].top)
			if(p[p[x].top].deep>p[p[y].top].deep)
			ans=max(ans,askmax(1,p[p[x].top].id,p[x].id)),x=p[p[x].top].fa;
			else ans=max(ans,askmax(1,p[p[y].top].id,p[y].id)),y=p[p[y].top].fa;
			ans=max(ans,askmax(1,min(p[x].id,p[y].id),max(p[x].id,p[y].id)));
			printf("%d\n",ans);
		}
		else if(op[3]=='M'){
			int ans=0;
			while(p[x].top!=p[y].top)
			if(p[p[x].top].deep>p[p[y].top].deep)
			ans+=asksum(1,p[p[x].top].id,p[x].id),x=p[p[x].top].fa;
			else ans+=asksum(1,p[p[y].top].id,p[y].id),y=p[p[y].top].fa;
			ans+=asksum(1,min(p[x].id,p[y].id),max(p[x].id,p[y].id));
			printf("%d\n",ans);
		}
		else add(1,p[x].id,y);
	}
	return 0;
}
#include<cstdio>//另一位大佬的注释程序
#include<cstring>
using namespace std;
const int N=31000;
const int M=124000;
int n,q,k=1,first[N],summ,maxmax;
struct Edge{ int v,next;}edge[M];
int num[N];
int father[N],dep[N],size[N],son[N],top[N],seg[N],rev[N];
int maxn[M],sum[M];

int read(){
	int s=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){ if(ch=='-') f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){ s=(s<<3)+(s<<1)+ch-48;ch=getchar();}
	return f*s;
}
void addedge(int ui,int vi){
   edge[++k].v=vi;edge[k].next=first[ui];first[ui]=k;
   edge[++k].v=ui;edge[k].next=first[vi];first[vi]=k;	
}
void dfs1(int u,int fa){      //第一遍dfs 
   size[u]=1;          //u的结点数为1,包含它自己 
   father[u]=fa;       //u的父亲为fa 
   dep[u]=dep[fa]+1;   //u的深度为其父亲fa的深度+1
   for(int i=first[u];i;i=edge[i].next){   //穷举与u开头的所有边 
   	 int v1=edge[i].v;
   	 if(v1==fa) continue;
   	 dfs1(v1,u);
   	 size[u]+=size[v1];     //u包含的结点数累加儿子v1的结点
	 if(size[v1]>size[son[u]])   //如果v1的结点数>u的重儿子son[u]的结点数 
	     son[u]=v1;      //更新重儿子 
   }    	
}
void dfs2(int u,int fa){      //第二遍dfs 
	if(son[u]){   //先走重儿子,保证重路径在线段树上的位置是连续的 
		seg[son[u]]=++seg[0];    //重儿子son[u]在线段树上的编号为++seg[0] 
		top[son[u]]=top[u];      //重儿子son[u]所在的重路径的顶端结点为其父亲u所在的顶端结点 
		rev[seg[0]]=son[u];      //线段树上编号为seg[0]的结点号为son[u]
		dfs2(son[u],u);		 
	}
	for(int i=first[u];i;i=edge[i].next){
		int v1=edge[i].v;
		if(top[v1]) continue;   //如果v1已访问过, 即为重儿子或父亲,则不需要再访问
		seg[v1]=++seg[0];  
		rev[seg[0]]=v1;
		top[v1]=v1;    //若(u,v1)为轻边,则v1就是其所在重路径的顶部结点
		dfs2(v1,u); 
	}
}
int max(int x,int y){ if(x>y) return x;return y;}
void build(int k,int l,int r){  //建立线段树
   int mid=(l+r)>>1;
   if(l==r){
   	  maxn[k]=sum[k]=num[rev[l]];
   	  return;
   } 
   build(k<<1,l,mid);
   build((k<<1)+1,mid+1,r);
   sum[k]=sum[k<<1]+sum[(k<<1)+1];
   maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]);	
}
void change(int k,int l,int r,int val,int pos){  //修改pos结点,改值为val
   if(pos<l||r<pos) return;   //如果不在范围内,就退出 
   if(l==r&&l==pos){    //如果找到点,则更改 
   	  sum[k]=val;
   	  maxn[k]=val;
   	  return;
   } 
   int mid=(l+r)>>1;
   if(mid>=pos) change(k<<1,l,mid,val,pos);    
   if(mid<pos) change((k<<1)+1,mid+1,r,val,pos);
   sum[k]=sum[k<<1]+sum[(k<<1)+1];
   maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]);	   
}
void swap(int &x,int &y){ int temp=x;x=y;y=temp;}
void query(int k,int l,int r,int x,int y){
	if(y<l||x>r) return;
	if(x<=l&&r<=y){
		summ+=sum[k];
		maxmax=max(maxmax,maxn[k]);
		return;
	}
	int mid=(l+r)>>1;
	if(x<=mid) query(k<<1,l,mid,x,y);
	if(mid<y) query((k<<1)+1,mid+1,r,x,y);
}
void ask(int x,int y){  //路径询问
   int fx=top[x],fy=top[y];  //找到xy分别所在重路径的顶端结点
   while(fx!=fy){       //退出时xy在同一条重路径上 
   	   if(dep[fx]<dep[fy]){   //保证x的深度更大 
   	   	  swap(x,y);swap(fx,fy);  
   	   }
   	   query(1,1,seg[0],seg[fx],seg[x]);
   	   x=father[fx];fx=top[x];
   } 
   if(dep[x]>dep[y]) swap(x,y);    //保证x在线段树上的编号<y
   query(1,1,seg[0],seg[x],seg[y]); 
}
int main(){
    n=read();
    for(int i=1;i<n;++i) addedge(read(),read());
    for(int i=1;i<=n;++i) num[i]=read();
    dfs1(1,0);
    seg[0]=1;    //记录截止到目前线段树的编号 
	seg[1]=1;    //根结点0和1在线段树中的位置为1 
	top[1]=1;   //1所在的重路径顶部结点为1 
	rev[1]=1;   //线段树中第1个位置对应的结点还是1
	dfs2(1,0); 
	build(1,1,seg[0]);    //建立线段树 
	q=read();
	while(q--){
		char st[10];
		scanf("%s",st);
		int ui=read(),vi=read();
		if(st[0]=='C')
		   change(1,1,seg[0],vi,seg[ui]);
		else{
			summ=0;
			maxmax=-100000000;
			ask(ui,vi);
			if(st[1]=='M') printf("%d\n",maxmax);
			else printf("%d\n",summ);
		}
	}	
	return 0;	
}

一次小尝试:

#include<cstdio>
#include<cstring>
using namespace std;
const int N=30005;
struct node{
	int v,next;
}e[N<<1];
int n,y,x,q,w[N],first[N],k=0,cnt=1;
int sum[N<<2],maxn[N<<2];
char st[10];
int sum1,maxn1
int dep[N],fa[N],size[N],son[N],top[N],seg[N]/*表示图中x点在线段树中的编号,即第二次深搜时的时间戳*/,rev[N];
void add(int x,int y){
	e[++k].v=y;e[k].next=first[x];first[x]=k;
}
void dfs1(int u,int fat){
	dep[u]=dep[fa]+1;
	size[u]=1;
	for(i=first[u];i;i=e[i].next){
		int vi=e[i].v;
		if(vi==fat) continue;
		fa[vi]=u;
		dfs1(vi,u);
		size[u]+=size[vi];
		if(size[vi]>size[son[u]]) son[u]=vi;
	}
}
void dfs2(int u,int fat){
	
	if(son[u]){
		seg[son[u]]=++cnt;
		rev[cnt]=son[u];
		top[son[u]]=top[u];
		dfs2(son[u],u);
	}
	for(int i=first[u];i;i=e[i].next){//对轻儿子的操作//即该儿子top为他本身 
		int vi=e[i].v;
		if(vi==fat)continue;
		if(!top[vi]){//表示vi没有访问过 
			seg[vi]=++cnt;
			rev[cnt]=vi;
			top[vi]=vi;
			dfs2(vi,u); 
		}
	}
}
int max(int x,int y){return x > y ? x : y;}
void build(int k,int l,int r){//建立线段树,计算对应的和、最大值 
	if(l==r){
		=w[rev[l]];
		return;
	}
	int mid=(l+r)>>1;
	build(k<<1,l,mid);
	built((k<<1)+1,mid+1,r);
	sum[k]=sum[k<<1]+sum[(k<<1)+1];
	maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]);
}
void change(int k,int l,int r,int u,int t){
	if(u<l||r<u) return;
	if(l==r&&l==u){
		maxn[k]=sum[k]=t;
		return;
	}
	int mid=(l+r)<<1;
	if(u<=mid) change(k<<1,l,mid,u,t);
	else change((k<<1)+1,mid+1,r,u,t);
	sum[k]=sum[k<<1]+sum[(k<<1)+1];
	maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]);
}
void query(int k,int l,int r,int x,int y){
	if(y<l||x>r)return;
	if(l<=x&&y<=r){////////////////////////////////////////////////////////
		sum1+=sum[k];
		maxn1=max(maxn1,maxn[k]);
		return;
	}
	int mid=(l+r)<<1;
	if(x<=mid)query(k<<1,l,mid,x,y);
	if(mid+1<=y)query((k<<1)+1,mid+1,r,x,y)
}
int swap(int &x,int &y){int temp=x;x=y;y=temp;}
void ask(int x,int y){
	int fx=top[x],fy=top[y];
	while(fx!=fy){
		if(dep[fx]<dep[fy]){
			swap(x,y),swap(fa,fy);
		}
		query(1,1,n,seg[fx],seg[x]);
		x=fa[fx];
		fx=top[x];
	}
	if(dep[])
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d%d",&x,&y);
		ad(x,y),add(y,x);
	}
	for(int i=1;i<=n;i++)scanf("%d",&w[i]);
	dfs1(1,0);
	seg[1]=cnt;rev[1]=1;top[1]=1;
	dfs2(1,0);
	build(1,1,n);
	scanf("%d",&q);
	while(q--){
		scanf("%s%d%d",st,&x,&y);
		if(st[0]=='C')
			change(1,1,n,seg[x],y);
		else{
			sum1=0;maxn1=-3000000;
			ask(x,y);
			if(st[1]=='M')
			printf("%d",maxn1);
			else
			printf("%d",sum1);
		}
	}
	return 0;
}

又一个标程:

#include<cstdio>
#include<cstring>
#define lc (k<<1)
#define rc (k<<1|1)
using namespace std;
const int N=3e4+5;
struct node{ int vi,next;}edge[N<<1];
int n,k,x,y,cnt=0,q,max1,sum1;
int son[N],fa[N],dep[N],size[N],top[N],id[N],rev[N];
int first[N],w[N],maxn[N<<2],sum[N<<2];
int read(){   //快速读入 
	int s=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){ if(ch=='-') f=-1;ch=getchar();	}
	while(ch>='0'&&ch<='9'){ s=(s<<3)+(s<<1)+ch-48;ch=getchar();	}
	return s*f;
}
void ADD(int x,int y){ edge[++k].vi=y;edge[k].next=first[x];first[x]=k;}  //邻接表存储树 
void dfs1(int u,int father){   //第一次深搜 
	fa[u]=father;     //u的父亲是father 
	dep[u]=dep[father]+1;   //u的深度为父亲的深度加1 
	size[u]=1;              //以u为根的子树结点数初始值为1,因为自己也算一个结点 
	for(int i=first[u];i;i=edge[i].next){   //访问u的所有儿子 
		int v=edge[i].vi;
		if(v==father) continue;     //如果儿子为父亲,则不访问,因为是无向图 
		dfs1(v,u);                //深搜儿子 
		size[u]+=size[v];         //u的结点数加上其儿子v的结点数 
		if(size[v]>size[son[u]]) son[u]=v;   //如果v的结点数>u的重儿子的结点数,则更新重儿子 
	}
}
void dfs2(int u){     //第二次深搜 
	id[u]=++cnt;     //u的时间戳是cnt,即在线段树上的编号 
	rev[cnt]=u;      //在线段树上编号位cnt的,对应树上的结点号是u 
	if(son[u]){      //如果u有重儿子 
		top[son[u]]=top[u];     //则u的重儿子son[u]所在重路径的深度最小的顶点=u的重路径顶点 
		dfs2(son[u]);     //深搜重儿子 
	}
	for(int i=first[u];i;i=edge[i].next){   //访问u的所有儿子 
		int v=edge[i].vi;
		if(id[v]) continue;    //如果儿子v访问过,则不需要再次访问,比如重儿子 
		top[v]=v;              //不是重儿子的v的top值为自己 
		dfs2(v);             //深搜儿子 
	}
}
int max(int x,int y){ if(x>y) return x;return y;}
void build(int k,int l,int r){   //建立线段树 
	if(l==r){             //如果是叶子结点 
		maxn[k]=sum[k]=w[rev[l]];  
		return;
	}
	int mid=(l+r)>>1;
	build(lc,l,mid);build(rc,mid+1,r);
	maxn[k]=max(maxn[lc],maxn[rc]);
	sum[k]=sum[lc]+sum[rc];
}
void change(int k,int l,int r,int u,int val){ //将结点u的值改为val
    if(u<l||r<u) return;    //如果u不在区间[l,r]中,则退出 
    if(l==r&&l==u){       //如果找到u,则更新 
    	maxn[k]=sum[k]=val;
		return;
    }	
	int mid=(l+r)>>1;
	if(u<=mid) change(lc,l,mid,u,val);   //如果u在左子树中,则更改左子树 
	else change(rc,mid+1,r,u,val);     //否则更改右子树 
	maxn[k]=max(maxn[lc],maxn[rc]);    //更新区间[l,r]的最大值 
	sum[k]=sum[lc]+sum[rc];      //更新区间[l,r]的和 
}
void swap(int &x,int &y){int temp=x;x=y;y=temp;}
void query(int k,int l,int r,int x,int y){   //询问线段上[x,y]区间上的值 
	if(y<l||r<x) return;    //如果区间[l,r]和[x,y]无交集,则退出 
	if(x<=l&&r<=y){      //如果区间[x,y]包含[l,r],则更新max1,sum1 
		max1=max(max1,maxn[k]);
		sum1+=sum[k];
		return;
	}
	int mid=(l+r)>>1;
	if(mid>=x) query(lc,l,mid,x,y);    //如果区间[x,y]与[l,mid]有交集,则查找左子树 
	if(mid+1<=y) query(rc,mid+1,r,x,y);	 //如果区间[x,y]与[mid+1,r]有交集,则查找右子树
}
void ask(int u,int v){      //询问树上结点u,v之间的值 
	while(top[u]!=top[v]){     //如果u,v不在一个重路径上 
		if(dep[top[u]]<dep[top[v]]){ swap(u,v);	}   //保证top[u]的深度大于top[v],否则交换两者 
		query(1,1,n,id[top[u]],id[u]);    //因为后面u要跳到top[u]的父亲处,故要将u~top[u]之间的路径上的值更新,其在线段树上的编号是连续的,即id[u ~id[top[u]]
		u=fa[top[u]];     //u跳到其top[u]的父亲处 
	}
	if(dep[u]<dep[v]) swap(u,v);   //while结束时,u和v必然在同一个重路径上,if是为了保证u的深度大于v 
	query(1,1,n,id[v],id[u]);   //结点v~u在线段树上的编号是连续的,故查找id[v]~id[u] 
}

int main(){
	//freopen("count.in","r",stdin);freopen("count.out","w",stdout);
	n=read();
	for(int i=1;i<n;++i){ x=read();y=read();ADD(x,y);ADD(y,x);}
	for(int i=1;i<=n;++i) w[i]=read();
	dfs1(1,0);
	dfs2(1);
	//for(int i=1;i<=n;++i)   printf("%d:son%d fa%d dep%d size%d top%d id%d rev%d\n",i,son[i],fa[i],dep[i],size[i],top[i],id[i],rev[i] );
	build(1,1,cnt);
	q=read();
	while(q--){
		char st[10];
		scanf("%s",st);x=read();y=read();
		if(st[0]=='C')
		   change(1,1,n,id[x],y);
	    else{
	       max1=-3000000;sum1=0;
	       ask(x,y);
	       if(st[1]=='M') printf("%d\n",max1);
	       else printf("%d\n",sum1);
	    }		
	}	
	//fclose(stdin);fclose(stdout);
	return 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值