【题目描述】
HYSBZ - 2157树链剖分
【题目分析】
这道题给出的是边权而不是点权。但是树上除根以外的每个节点都只有一个父亲,所以每条边的边权都可以存放在它较深的端点(儿子节点)上,这样就转化成了点权问题。处理路径时照常沿着重链一段一段向上跳;但要注意,当两个点跳到同一条重链之后,深度较浅的那个点(也就是 LCA)不能算在内:它存储的是它与自己父亲之间那条边的权值,这条边并不在路径上,所以最后一段区间要从 LCA 在线段树中的下一个位置开始。
接下来就是很长的代码了。我自己写的版本一直 WA,看了一天也没有找出问题在哪里,所以先借鉴大佬的代码:
【AC代码】
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
using namespace std;
// Capacity used to size the per-node arrays and, scaled, the edge and
// segment-tree arrays declared below.
const int maxn = 120000;
// Shorthand for the 64-bit integer type used throughout this program.
#define ll long long
// Segment-tree child indices. Both macros expand in terms of a variable named
// `now`, so they are only usable where such a variable is in scope.
#define ls now<<1
#define rs now<<1|1
// Reads the next decimal integer from stdin using getchar().
//
// Every non-digit character before the number is skipped; if any skipped
// character is '-', the result is negated. Returns 0 when end of input is
// reached before any digit is seen.
inline int read()
{
    int value = 0, sign = 1;
    // getchar() returns int so that EOF stays distinguishable from real
    // characters; storing it in a char made the skip loop spin forever at EOF.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// One directed entry of the adjacency list; buildnode() fills these in.
struct node
{
    ll f;   // endpoint the entry is stored under (head[f] chains to it)
    ll t;   // the other endpoint
    ll v;   // edge weight
}e[maxn << 1];
// One segment-tree node covering positions [l, r].
struct tree
{
    ll l;       // left bound of the covered range
    ll r;       // right bound of the covered range
    ll sum;     // sum of the values in the range
    ll maxx;    // maximum value in the range
    ll minn;    // minimum value in the range
    ll add;     // not referenced by the code visible in this file
    ll mul;     // 0/1 flag: a negation is still pending for the children
    ll num;     // number of positions covered (r - l + 1, set in build)
}tre[maxn << 2];
// n: number of nodes, m: number of operations (both read in main),
// tot: number of adjacency entries created so far (buildnode),
// cnt: last segment-tree position handed out (dfs2).
ll n, m, tot, cnt;
// head/nxt: adjacency-list heads and next links (index 0 terminates a list).
// used: not referenced by the code visible in this file.
// dep: depth, fa: parent, sz: subtree size (all filled by dfs1).
// w1[v]: weight of the edge between v and its parent (filled by dfs1).
ll head[maxn], nxt[maxn << 1], used[maxn], dep[maxn], fa[maxn], sz[maxn], w1[maxn];
// w2[p]: weight stored at segment-tree position p, son[x]: heavy child of x,
// top[x]: topmost node of the chain containing x, num[x]: position of x.
ll w2[maxn] , son[maxn] , top[maxn] , num[maxn];
// Appends the directed entry a -> b with weight c to the front of a's
// adjacency list.
inline void buildnode(ll a, ll b , ll c)
{
    const ll id = ++tot;
    node &edge = e[id];
    edge.f = a;
    edge.t = b;
    edge.v = c;
    // Link the new entry in front of whatever a's list held before.
    nxt[id] = head[a];
    head[a] = id;
}
// Recomputes sum / max / min of node `now` from its two children.
inline void pushup(ll now)
{
    const tree &left = tre[ls];
    const tree &right = tre[rs];
    tree &cur = tre[now];
    cur.sum = left.sum + right.sum;
    cur.maxx = max(left.maxx, right.maxx);
    cur.minn = min(left.minn, right.minn);
}
// Builds the segment tree for positions [l, r] rooted at index `now`,
// taking leaf values from w2.
inline void build(ll now, ll l, ll r)
{
    tree &cur = tre[now];
    cur.l = l;
    cur.r = r;
    cur.num = r - l + 1;
    if (l != r)
    {
        const ll mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        pushup(now);
        return;
    }
    // Leaf: all three aggregates are the single stored value.
    cur.sum = cur.maxx = cur.minn = w2[l];
}
// Pushes a pending negation from node `now` down to both children.
// The mul flag only ever holds 0 or 1, so toggling it with ^= 1 matches
// adding the parent's flag modulo 2.
inline void pushdown(ll now)
{
    if (!tre[now].mul) return;
    const ll kids[2] = { ls, rs };
    for (int i = 0; i < 2; i++)
    {
        tree &child = tre[kids[i]];
        child.sum = -child.sum;
        // Negating a range swaps the roles of its maximum and minimum.
        const ll oldMax = child.maxx;
        child.maxx = -child.minn;
        child.minn = -oldMax;
        child.mul ^= 1;
    }
    tre[now].mul = 0;
}
// Sets the value stored at position x to p, starting the descent at `now`.
inline void update1(ll now, ll x , ll p)
{
    tree &cur = tre[now];
    if (cur.l == cur.r)
    {
        cur.sum = cur.maxx = cur.minn = p;
        return;
    }
    // Apply any pending negation before touching the children.
    pushdown(now);
    const ll mid = (cur.l + cur.r) >> 1;
    update1(x <= mid ? ls : rs, x, p);
    pushup(now);
}
// Negates every value stored at positions [l, r] in the subtree rooted at
// `now`. An empty range (l > r) is a no-op.
inline void update2(ll now, ll l, ll r)
{
    // cline() passes num[x] + 1 .. num[y], which is empty when both endpoints
    // meet. Without this guard such a call could descend to the leaf for the
    // last position and let pushup() overwrite it from children it doesn't have.
    if (l > r) return;
    tree &cur = tre[now];
    if (l <= cur.l && cur.r <= r)
    {
        cur.sum = -cur.sum;
        cur.mul = (cur.mul + 1) % 2;
        // Negation swaps maximum and minimum.
        const ll oldMax = cur.maxx;
        cur.maxx = -cur.minn;
        cur.minn = -oldMax;
        return;
    }
    pushdown(now);
    const ll mid = (cur.l + cur.r) >> 1;
    if (l <= mid) update2(ls, l, r);
    if (r > mid) update2(rs, l, r);
    pushup(now);
}
// First DFS pass from x (whose parent is fat): records depth, parent and
// subtree size, stores each parent-edge weight on the child, and picks the
// child with the largest subtree as the heavy child.
inline void dfs1(ll x, ll fat)
{
    fa[x] = fat;
    dep[x] = dep[fat] + 1;
    sz[x] = 1;
    int heaviest = -1;
    for (int idx = head[x]; idx != 0; idx = nxt[idx])
    {
        const int child = e[idx].t;
        if (child != fat)
        {
            // The edge weight lives on the deeper endpoint.
            w1[child] = e[idx].v;
            dfs1(child, x);
            sz[x] += sz[child];
            if (sz[child] > heaviest)
            {
                heaviest = sz[child];
                son[x] = child;
            }
        }
    }
}
// Second DFS pass: assigns x the next segment-tree position, copies its
// stored weight into w2, and records topx as the top of its chain. The heavy
// child is visited first so each chain occupies consecutive positions.
inline void dfs2(ll x , ll topx)
{
    top[x] = topx;
    num[x] = ++cnt;
    w2[num[x]] = w1[x];
    const ll heavy = son[x];
    if (heavy == 0) return;
    dfs2(heavy, topx);
    for (int idx = head[x]; idx != 0; idx = nxt[idx])
    {
        const int child = e[idx].t;
        // Every other child starts a new chain of its own.
        if (child != fa[x] && child != heavy)
            dfs2(child, child);
    }
}
inline int querys(ll now, ll l, ll r)
{
if (l <= tre[now].l && tre[now].r <= r)
return tre[now].sum;
pushdown(now);
ll mid = (tre[now].l + tre[now].r) >> 1, ans = 0;
if (l <= mid) ans += querys(ls, l, r);
if (r > mid) ans += querys(rs, l, r);
return ans;
}
// Returns the maximum value at positions [l, r] within the subtree rooted at
// `now`; yields -1e9 when no visited node lies inside the range.
inline ll query_max(ll now, ll l , ll r)
{
    const tree &cur = tre[now];
    if (l <= cur.l && cur.r <= r) return cur.maxx;
    pushdown(now);
    const ll mid = (cur.l + cur.r) >> 1;
    ll best = -1e9;
    if (l <= mid) best = max(best, query_max(ls, l, r));
    if (r > mid) best = max(best, query_max(rs, l, r));
    return best;
}
// Returns the minimum value at positions [l, r] within the subtree rooted at
// `now`; yields 1e9 when no visited node lies inside the range.
inline ll query_min(ll now, ll l , ll r)
{
    const tree &cur = tre[now];
    if (l <= cur.l && cur.r <= r) return cur.minn;
    pushdown(now);
    const ll mid = (cur.l + cur.r) >> 1;
    ll best = 1e9;
    if (l <= mid) best = min(best, query_min(ls, l, r));
    if (r > mid) best = min(best, query_min(rs, l, r));
    return best;
}
// Returns the sum of the edge weights on the tree path between x and y.
inline ll qline(ll x, ll y)
{
    ll total = 0;
    // Repeatedly consume the whole chain segment of whichever endpoint has
    // the deeper chain top, then hop to the parent of that top.
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        total += querys(1, num[top[x]], num[x]);
    }
    if (dep[x] > dep[y]) swap(x, y);
    // x is now the shallower endpoint; its own stored weight belongs to the
    // edge above it, so the final range starts one position later.
    return total + querys(1, num[x] + 1, num[y]);
}
// Negates the weight of every edge on the tree path between x and y.
inline void cline(ll x, ll y)
{
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        // Always work on the endpoint whose chain top is deeper.
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update2(1, num[top[x]], num[x]);
    }
    if (dep[x] > dep[y]) swap(x, y);
    // Skip the shallower endpoint: its stored weight is the edge above it.
    update2(1, num[x] + 1, num[y]);
}
// Returns the largest edge weight on the tree path between x and y
// (-1e9 if the path contains no edge).
inline ll qpmax(ll x, ll y)
{
    ll best = -1e9;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        best = max(best, query_max(1, num[top[x]], num[x]));
    }
    if (dep[x] > dep[y]) swap(x, y);
    // The shallower endpoint itself is excluded from the last segment.
    return max(best, query_max(1, num[x] + 1, num[y]));
}
// Returns the smallest edge weight on the tree path between x and y
// (1e9 if the path contains no edge).
inline ll qpmin(ll x, ll y)
{
    ll best = 1e9;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        best = min(best, query_min(1, num[top[x]], num[x]));
    }
    if (dep[x] > dep[y]) swap(x, y);
    // The shallower endpoint itself is excluded from the last segment.
    return min(best, query_min(1, num[x] + 1, num[y]));
}
string s1 = "C", s2 = "N", s3 = "SUM", s4 = "MAX", s5 = "MIN";
int main()
{
//freopen("1.in","r",stdin);
n = read();
for (int i = 1; i <= n - 1; i++)
{
int a, b, c;
a = read(); b = read(); c = read();
buildnode(a + 1, b + 1, c);
buildnode(b + 1, a + 1, c);
}
dfs1(1 , 0);
dfs2(1 , 1);
build(1 , 1 , n);
m = read();
while (m--)
{
string s;
int a, b, k;
cin >> s >> a >> b;
a += 1 , b += 1;
if (s == s1)
{
a -= 1 , b -= 1;
if (dep[e[a * 2 - 1].f] > dep[e[a * 2 - 1].t])
k = e[a * 2 - 1].f;
else k = e[a * 2 - 1].t;
update1(1, num[k], b);
}
if (s == s2) cline(a, b);
if (s == s3) printf("%lld\n", qline(a, b));
if (s == s4) printf("%lld\n", qpmax(a, b));
if (s == s5) printf("%lld\n", qpmin(a, b));
}
return 0;
}
我自己写的也贴上,方便日后检查
#include<cstdio>
#include<cstring>
#include<cmath>
#include<climits>
#include<cstdlib>
#include<algorithm>
#include<queue>
#include<vector>
#include<set>
using namespace std;
// 64-bit integer alias (not used by the code visible in this file).
typedef long long ll;
// Capacity used to size the per-node, edge and segment-tree arrays.
const int MAXN=120005;
// fa: parent, val[v]: weight of the edge between v and its parent (dfs1),
// pos[u]: segment-tree position of u, A[p]: weight stored at position p (dfs2).
int fa[MAXN],A[MAXN],val[MAXN],pos[MAXN];
// siz: subtree size, son: heavy child, h: depth (dfs1); top: chain top (dfs2).
int siz[MAXN],son[MAXN],h[MAXN],top[MAXN];
// cnt: last position handed out by dfs2; n: node count; m: operation count.
int cnt=0,n,m;
// Segment tree: per-node sum / max / min plus a 0/1 pending-negation flag.
int Sum[MAXN<<2],Max[MAXN<<2],Min[MAXN<<2],lazy[MAXN<<2];
// One directed entry of the adjacency list; AddEdge() fills these in.
struct node
{
    int u;  // endpoint the entry is stored under
    int v;  // the other endpoint
    int w;  // edge weight
}e[MAXN<<1];
// Adjacency list: head[u] is the most recently added entry for u, nxt[i] the
// entry that was first before i was added; 0 terminates a list.
int head[MAXN<<1],nxt[MAXN<<1];
// Number of adjacency entries created so far (entries are 1-based).
int tot=0;
// Appends the directed entry u -> v with weight w to the front of u's
// adjacency list.
void AddEdge(int u,int v,int w)
{
    const int id = ++tot;
    e[id].u = u;
    e[id].v = v;
    e[id].w = w;
    nxt[id] = head[u];
    head[u] = id;
}
// First DFS pass from u (whose parent is f): records parent, depth and
// subtree size, stores each parent-edge weight on the child, and keeps the
// child with the largest subtree as the heavy child.
void dfs1(int u,int f)
{
    fa[u] = f;
    h[u] = h[f] + 1;
    siz[u] = 1;
    son[u] = 0;
    for (int idx = head[u]; idx != 0; idx = nxt[idx])
    {
        const int child = e[idx].v;
        if (child == f) continue;
        // The edge weight lives on the deeper endpoint.
        val[child] = e[idx].w;
        dfs1(child, u);
        siz[u] += siz[child];
        if (siz[son[u]] < siz[child]) son[u] = child;
    }
}
// Second DFS pass: gives u (parent f) the next segment-tree position, copies
// its stored weight into A, and records k as the top of its chain. The heavy
// child is visited first so each chain occupies consecutive positions.
void dfs2(int u,int f,int k)
{
    top[u] = k;
    pos[u] = ++cnt;
    A[pos[u]] = val[u];
    const int heavy = son[u];
    if (heavy == 0) return;
    dfs2(heavy, u, k);
    for (int idx = head[u]; idx != 0; idx = nxt[idx])
    {
        const int child = e[idx].v;
        if (child == f || child == heavy) continue;
        // Every other child starts a new chain of its own.
        dfs2(child, u, child);
    }
}
// Recomputes sum / max / min of node k from its two children.
void pushup(int k)
{
    const int lc = k << 1;
    const int rc = lc | 1;
    Sum[k] = Sum[lc] + Sum[rc];
    Max[k] = max(Max[lc], Max[rc]);
    Min[k] = min(Min[lc], Min[rc]);
}
// Pushes a pending negation from node k down to both children.
void pushdown(int k)
{
    if (!lazy[k]) return;
    const int first = k << 1;
    for (int child = first; child <= first + 1; ++child)
    {
        Sum[child] = -Sum[child];
        // Negating a range swaps the roles of its maximum and minimum.
        const int oldMax = Max[child];
        Max[child] = -Min[child];
        Min[child] = -oldMax;
        lazy[child] ^= 1;
    }
    lazy[k] = 0;
}
// Builds the segment tree for positions [l, r] rooted at index k, taking
// leaf values from A.
void build(int k,int l,int r)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        build(k << 1, l, mid);
        build(k << 1 | 1, mid + 1, r);
        pushup(k);
        return;
    }
    // Leaf: all three aggregates are the single stored value.
    Sum[k] = A[l];
    Max[k] = A[l];
    Min[k] = A[l];
}
// Sets the value at position x to v; node k covers positions [l, r].
void PointChange(int k,int l,int r,int x,int v)
{
    if (l == x && r == x)
    {
        Sum[k] = v;
        Max[k] = v;
        Min[k] = v;
        return;
    }
    // Apply any pending negation before descending.
    pushdown(k);
    const int mid = (l + r) >> 1;
    if (x > mid)
        PointChange(k << 1 | 1, mid + 1, r, x, v);
    else
        PointChange(k << 1, l, mid, x, v);
    pushup(k);
}
// Negates every value at positions [L, R]; node k covers positions [l, r].
// An empty range (L > R) is a no-op.
void IntervalChange(int k,int l,int r,int L,int R)
{
    // update() passes pos[v] + 1 .. pos[u], which is empty when both endpoints
    // meet. Without this guard such a call could descend to the leaf for the
    // last position, where pushdown()/pushup() would touch children the leaf
    // does not have and overwrite the leaf's value.
    if (L > R) return;
    if (l >= L && r <= R)
    {
        Sum[k] = -Sum[k];
        // Negation swaps maximum and minimum.
        const int oldMax = Max[k];
        Max[k] = -Min[k];
        Min[k] = -oldMax;
        lazy[k] ^= 1;
        return;
    }
    pushdown(k);
    const int mid = (l + r) >> 1;
    if (L <= mid) IntervalChange(k << 1, l, mid, L, R);
    if (R > mid) IntervalChange(k << 1 | 1, mid + 1, r, L, R);
    pushup(k);
}
// Returns the sum of the values at positions [L, R]; node k covers [l, r].
int IntervalSum(int k,int l,int r,int L,int R)
{
    if (L <= l && r <= R) return Sum[k];
    const int mid = (l + r) / 2;
    // Children must see any pending negation before they are read.
    pushdown(k);
    int total = 0;
    if (L <= mid) total += IntervalSum(k << 1, l, mid, L, R);
    if (R > mid) total += IntervalSum(k << 1 | 1, mid + 1, r, L, R);
    return total;
}
// Returns the maximum value at positions [L, R]; node k covers [l, r].
// Yields INT_MIN when the requested range is empty.
int IntervalMax(int k,int l,int r,int L,int R)
{
    if (L <= l && r <= R)
    {
        return Max[k];
    }
    const int mid = (l + r) / 2;
    // Children must see any pending negation before they are read.
    pushdown(k);
    int ret = INT_MIN;
    // Recurse into IntervalMax itself. The previous code called IntervalSum
    // here, so any partially covered node contributed a sum, not a maximum.
    if (L <= mid) ret = max(IntervalMax(k << 1, l, mid, L, R), ret);
    if (R > mid) ret = max(IntervalMax(k << 1 | 1, mid + 1, r, L, R), ret);
    return ret;
}
// Returns the minimum value at positions [L, R]; node k covers [l, r].
// Yields INT_MAX when the requested range is empty.
int IntervalMin(int k,int l,int r,int L,int R)
{
    if (L <= l && r <= R)
    {
        return Min[k];
    }
    const int mid = (l + r) / 2;
    // Children must see any pending negation before they are read.
    pushdown(k);
    int ret = INT_MAX;
    // Recurse into IntervalMin itself. The previous code called IntervalSum
    // here, so any partially covered node contributed a sum, not a minimum.
    if (L <= mid) ret = min(IntervalMin(k << 1, l, mid, L, R), ret);
    if (R > mid) ret = min(IntervalMin(k << 1 | 1, mid + 1, r, L, R), ret);
    return ret;
}
// Returns the sum of the edge weights on the tree path between u and v.
int FindSum(int u,int v)
{
    int total = 0;
    // Consume the chain segment of whichever endpoint has the deeper chain
    // top, then hop to the parent of that top.
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (h[top[u]] < h[top[v]]) swap(u, v);
        total += IntervalSum(1, 1, n, pos[top[u]], pos[u]);
    }
    if (h[u] < h[v]) swap(u, v);
    // v is now the shallower endpoint; its stored weight is the edge above
    // it, so the final range starts one position later.
    return total + IntervalSum(1, 1, n, pos[v] + 1, pos[u]);
}
// Returns the largest edge weight on the tree path between u and v
// (INT_MIN if the path contains no edge).
int FindMax(int u,int v)
{
    int best = INT_MIN;
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (h[top[u]] < h[top[v]]) swap(u, v);
        best = max(IntervalMax(1, 1, n, pos[top[u]], pos[u]), best);
    }
    if (h[u] < h[v]) swap(u, v);
    // The shallower endpoint v is excluded from the last segment.
    return max(IntervalMax(1, 1, n, pos[v] + 1, pos[u]), best);
}
// Returns the smallest edge weight on the tree path between u and v
// (INT_MAX if the path contains no edge).
int FindMin(int u,int v)
{
    int best = INT_MAX;
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (h[top[u]] < h[top[v]]) swap(u, v);
        best = min(IntervalMin(1, 1, n, pos[top[u]], pos[u]), best);
    }
    if (h[u] < h[v]) swap(u, v);
    // The shallower endpoint v is excluded from the last segment.
    return min(IntervalMin(1, 1, n, pos[v] + 1, pos[u]), best);
}
// Negates the weight of every edge on the tree path between u and v.
void update(int u,int v)
{
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        // Always work on the endpoint whose chain top is deeper.
        if (h[top[u]] < h[top[v]]) swap(u, v);
        IntervalChange(1, 1, n, pos[top[u]], pos[u]);
    }
    if (h[u] < h[v]) swap(u, v);
    // Skip the shallower endpoint v: its stored weight is the edge above it.
    IntervalChange(1, 1, n, pos[v] + 1, pos[u]);
}
// Reads the tree (0-based node labels in the input, shifted to 1-based
// internally), builds the decomposition and segment tree, then executes the
// m commands. Stops quietly if the input ends early or is malformed.
int main()
{
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 1; i < n; i++)
    {
        int u, v, w;
        if (scanf("%d%d%d", &u, &v, &w) != 3) return 0;
        AddEdge(u + 1, v + 1, w);
        AddEdge(v + 1, u + 1, w);
    }
    dfs1(1, 0);
    dfs2(1, 0, 1);
    build(1, 1, n);
    if (scanf("%d", &m) != 1) return 0;
    while (m--)
    {
        char cmd[10];
        int a, b;
        // %9s bounds the token to the buffer; a plain %s could overflow cmd.
        // Every command is followed by exactly two integers.
        if (scanf("%9s%d%d", cmd, &a, &b) != 3) break;
        if (cmd[0] == 'C')
        {
            // "C i w": input edge i sits at adjacency entry 2*i - 1; its
            // weight is stored on the deeper endpoint.
            const int u = e[a * 2 - 1].u;
            const int v = e[a * 2 - 1].v;
            PointChange(1, 1, n, h[u] < h[v] ? pos[v] : pos[u], b);
        }
        else if (cmd[0] == 'N')
        {
            update(a + 1, b + 1);
        }
        else if (cmd[0] == 'S')
        {
            printf("%d\n", FindSum(a + 1, b + 1));
        }
        else if (cmd[1] == 'A')
        {
            printf("%d\n", FindMax(a + 1, b + 1));
        }
        else if (cmd[1] == 'I')
        {
            printf("%d\n", FindMin(a + 1, b + 1));
        }
    }
    return 0;
}