首先贴一个别人的动态开点模板
树链剖分详解
#include<iostream> #include<cstdio> using namespace std; const int maxn=1e5+10; struct edge{ int next,to; }e[2*maxn]; struct Node{ int sum,lazy,l,r,ls,rs; }node[2*maxn]; int rt,n,m,r,a[maxn],cnt,head[maxn],f[maxn],d[maxn],size[maxn],son[maxn],rk[maxn],top[maxn],tid[maxn]; void add_edge(int x,int y) { e[++cnt].next=head[x]; e[cnt].to=y; head[x]=cnt; } void dfs1(int u,int fa,int depth) { f[u]=fa; d[u]=depth; size[u]=1; for(int i=head[u];i;i=e[i].next) { int v=e[i].to; if(v==fa) continue; dfs1(v,u,depth+1); size[u]+=size[v]; if(size[v]>size[son[u]]) son[u]=v; } } void dfs2(int u,int t) { top[u]=t; tid[u]=++cnt; rk[cnt]=u; if(!son[u]) return; dfs2(son[u],t); for(int i=head[u];i;i=e[i].next) { int v=e[i].to; if(v!=son[u]&&v!=f[u]) dfs2(v,v); } } void pushup(int x) { node[x].sum=(node[node[x].ls].sum+node[node[x].rs].sum+node[x].lazy*(node[x].r-node[x].l+1)); } void build(int li,int ri,int cur) { if(li==ri) { node[cur].ls=node[cur].rs=-1; node[cur].l=node[cur].r=li; node[cur].sum=a[rk[li]]; return; } int mid=(li+ri)>>1; node[cur].ls=cnt++; node[cur].rs=cnt++; build(li,mid,node[cur].ls); build(mid+1,ri,node[cur].rs); node[cur].l=node[node[cur].ls].l; node[cur].r=node[node[cur].rs].r; pushup(cur); } void update(int li,int ri,int c,int cur) { if(li<=node[cur].l&&node[cur].r<=ri) { node[cur].sum+=c*(node[cur].r-node[cur].l+1); node[cur].lazy+=c; return; } int mid=(node[cur].l+node[cur].r)>>1; if(li<=mid) update(li,ri,c,node[cur].ls); if(mid<ri) update(li,ri,c,node[cur].rs); pushup(cur); } int query(int li,int ri,int cur) { if(li<=node[cur].l&&node[cur].r<=ri) return node[cur].sum; int tot=node[cur].lazy*(min(node[cur].r,ri)-max(node[cur].l,li)+1); int mid=(node[cur].l+node[cur].r)>>1; if(li<=mid) tot+=query(li,ri,node[cur].ls); if(mid<ri) tot+=query(li,ri,node[cur].rs); return tot; } int sum(int x,int y) { int ans=0,fx=top[x],fy=top[y]; while(fx!=fy) { if(d[fx]>=d[fy]) { ans+=query(tid[fx],tid[x],rt); x=f[fx]; } else { ans+=query(tid[fy],tid[y],rt); y=f[fy]; } fx=top[x]; fy=top[y]; } if(tid[x]<=tid[y]) ans+=query(tid[x],tid[y],rt); else ans+=query(tid[y],tid[x],rt); return ans; } void updates(int x,int y,int c) { int fx=top[x],fy=top[y]; while(fx!=fy) { if(d[fx]>=d[fy]) { update(tid[fx],tid[x],c,rt); x=f[fx]; } else { update(tid[fy],tid[y],c,rt); y=f[fy]; } fx=top[x]; fy=top[y]; } if(tid[x]<=tid[y]) update(tid[x],tid[y],c,rt); else update(tid[y],tid[x],c,rt); } int main() { cin>>n>>m>>r; for(int i=1;i<=n;i++) cin>>a[i]; for(int i=1;i<n;i++) { int x,y; cin>>x>>y; add_edge(x,y); add_edge(y,x); } cnt=0; dfs1(r,0,1); dfs2(r,r); cnt=0; rt=cnt++; build(1,n,rt); for(int i=1;i<=m;i++) { int op,x,y,z; cin>>op; if(op==1) { cin>>x>>y>>z; updates(x,y,z); } else if(op==2) { cin>>x>>y; cout<<sum(x,y)<<endl; } else if(op==3) { cin>>x>>z; //子树也有连续区间的性质 update(tid[x],tid[x]+size[x]-1,z,rt); } else if(op==4) { cin>>x; cout<<query(tid[x],tid[x]+size[x]-1,rt)<<endl; } } return 0; }
1. [ZJOI2008]树的统计
题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入输出格式
输入格式:
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
输出格式:
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
输入输出样例
输入样例#1: 复制
4 1 2 2 3 4 1 4 2 1 3 12 QMAX 3 4 QMAX 3 3 QMAX 3 2 QMAX 2 3 QSUM 3 4 QSUM 2 1 CHANGE 1 5 QMAX 3 4 CHANGE 3 6 QMAX 3 4 QMAX 2 4 QSUM 3 4
输出样例#1: 复制
4 1 2 2 10 6 5 6 5 16
说明
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
#include<bits/stdc++.h>
#define dprintf if (debug) printf
using namespace std;
const int maxn=30050;
const int INF=0x7fffffff;
const int debug = 0;
struct edge{
int next,to;
}e[2*maxn];
struct Node{
int sum,maxx;
}T[4*maxn];
int n,m,r,a[maxn],cnt,head[maxn],f[maxn],d[maxn],size[maxn],son[maxn],rk[maxn],top[maxn],tid[maxn];
void add_edge(int x,int y)
{
e[++cnt].next=head[x];
e[cnt].to=y;
head[x]=cnt;
}
void dfs1(int u,int fa,int depth)
{
f[u]=fa;
d[u]=depth;
size[u]=1;
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa)
continue;
dfs1(v,u,depth+1);
size[u]+=size[v];
if(size[v]>size[son[u]])
son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;
tid[u]=++cnt;
rk[cnt]=u;
if(!son[u])
return;
dfs2(son[u],t);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v!=son[u]&&v!=f[u])
dfs2(v,v);
}
}
void build(int o,int l,int r)
{
dprintf("build %d %d %d\n", o, l, r);
if(l == r)
{
T[o].maxx = a[rk[l]];
T[o].sum = a[rk[l]];
//dprintf("cur = %d li = %d maxx = a[%d] = %d\n", cur, li, rk[li], node[cur].maxx);
return;
}
int mid=(l+r)>>1;
build(o<<1, l, mid);
build(o<<1|1, mid+1, r);
T[o].sum = T[o<<1].sum + T[o<<1|1].sum;
T[o].maxx = max(T[o<<1].maxx, T[o<<1|1].maxx);
}
void update(int o, int l, int r, int qx, int c)
{
if(l == r)
{
T[o].sum=c;
T[o].maxx=c;
return;
}
int mid=(l+r)>>1;
if(qx<=mid)
update(o<<1, l, mid, qx, c);
if(qx>=mid+1)
update(o<<1|1, mid+1, r, qx, c);
T[o].sum = T[o<<1].sum + T[o<<1|1].sum;
T[o].maxx = max(T[o<<1].maxx, T[o<<1|1].maxx);
}
int getsum(int o, int l, int r, int ql, int qr)
{
if (ql<=l && qr>=r)
return T[o].sum;
int tot=0;
int mid=(l+r)>>1;
if(ql<=mid)
tot+=getsum(o<<1, l, mid, ql, qr);
if(mid+1<=qr)
tot+=getsum(o<<1|1, mid+1, r, ql, qr);
return tot;
}
int sum(int x,int y)
{
int ans=0,fx=top[x],fy=top[y];
while(fx!=fy)
{
if(d[fx]>=d[fy])
{
ans+=getsum(1, 1, n, tid[fx],tid[x]);
x=f[fx];
}
else
{
ans+=getsum(1, 1, n, tid[fy],tid[y]);
y=f[fy];
}
fx=top[x];
fy=top[y];
}
if(tid[x]<=tid[y])
ans+=getsum(1, 1, n, tid[x],tid[y]);
else
ans+=getsum(1, 1, n, tid[y],tid[x]);
return ans;
}
int findmax(int o, int l, int r, int ql, int qr)
{
//dprintf("findmax %d %d %d maxx=%d\n", li, ri, cur, T[cur].maxx);
if (ql<=l && r<=qr)
return T[o].maxx;
int tot = -INF;
int mid=(l+r)>>1;
if (ql<=mid)
tot=max(tot, findmax(o<<1, l, mid, ql, qr));
if (qr>=mid+1)
tot=max(tot, findmax(o<<1|1, mid+1, r, ql, qr));
return tot;
}
int findmaxs(int x,int y)
{
int maxx = -INF;
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(d[fx]>=d[fy])
{
maxx = max(findmax(1, 1, n, tid[fx],tid[x]), maxx);
x=f[fx];
}
else
{
maxx = max(findmax(1, 1, n, tid[fy],tid[y]), maxx);
y=f[fy];
}
fx=top[x];
fy=top[y];
}
if(tid[x]<=tid[y])
maxx = max(findmax(1, 1, n, tid[x],tid[y]), maxx);
else
maxx = max(findmax(1, 1, n, tid[y],tid[x]), maxx);
return maxx;
}
int main()
{
scanf("%d", &n);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d", &x, &y);
add_edge(x,y);
add_edge(y,x);
}
for(int i=1;i<=n;i++)
scanf("%d", &a[i]);
cnt=0;
r = 1;
dfs1(r,0,1);
dfs2(r,r);
cnt=0;
build(1, 1, n);
dprintf("hello\n");
scanf("%d", &m);
for(int i=1;i<=m;i++)
{
int x,y,z;
char op[10];
scanf("%s", op);
if(op[1]=='H')
{
scanf("%d%d", &x, &z);
update(1, 1, n, tid[x], z);
}
else if(op[1]=='S')
{
scanf("%d%d", &x, &y);
printf("%d\n", sum(x,y));
}
else if(op[1]=='M')
{
scanf("%d%d", &x, &y);
printf("%d\n", findmaxs(x, y));
}
}
return 0;
}
注意:读string用了一个cin,结果耗时多1倍,bzoj直接tle。cin还是要慎用,哪怕只有一句!