树链剖分经典模板题，但需要注意的是，线段树既要维护区间和，又要维护区间最大值。
第一次手写还是很难，感觉还不太熟练。
以下是 AC 代码：
#include<bits/stdc++.h>
using namespace std;
// Capacity used to size the per-vertex arrays in this file (1-based indexing,
// with a few spare slots).
const int maxn = 100005;
// Signed integer type of at least 64 bits used throughout this file.
// Declared as a typedef rather than a macro so that `ll` obeys normal C++
// scoping instead of being substituted textually everywhere.
typedef long long ll;
// Reads the next decimal integer from stdin.
//
// Every character before the first digit is skipped; if a '-' is seen while
// skipping, the result is negated. The character that terminates the digit
// run is consumed as well.
//
// Returns the parsed value, or 0 when end of input is reached before any
// digit (previously that case looped forever).
inline long long read()
{
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch = getchar();
    long long sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }

    long long value = 0;
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Argument-list shorthands for the segment-tree calls. They expand textually,
// so `ls` and `rs` are only usable where locals named k, l, r and mid exist.
// ls: left child  -> node k<<1   covering [l, mid]
#define ls k<<1,l,mid
// rs: right child -> node k<<1|1 covering [mid+1, r]
#define rs k<<1|1,mid+1,r
// root: the whole tree -> node 1 covering [1, n] (n is the global vertex count)
#define root 1,1,n
// v[i]: weight of vertex i as read from the input (1-based).
ll v[maxn<<2];
// n: number of vertices; m: number of operations to process.
ll n,m;
// head[x]: index in ed[] of the first arc leaving x (0 means none);
// tot: number of arcs stored in ed[] so far.
int head[maxn],tot;
// One directed arc of the adjacency list. Arcs leaving the same vertex are
// chained through `nxt`, starting from head[vertex].
struct node
{
    int nxt; // index in ed[] of the next arc from the same vertex; 0 ends the chain
    int to;  // vertex this arc points to
}ed[maxn<<1]; // add() stores two arcs per undirected edge, so reserve twice maxn
              // slots; slot 0 is never used because 0 is the end-of-chain marker
// Inserts the undirected edge a-b into the adjacency list as two arcs,
// a->b first and then b->a. Each new arc becomes the front of its source
// vertex's chain.
void add(int a,int b)
{
    const int forward = ++tot;
    ed[forward].to = b;
    ed[forward].nxt = head[a];
    head[a] = forward;

    const int backward = ++tot;
    ed[backward].to = a;
    ed[backward].nxt = head[b];
    head[b] = backward;
}
// Heavy-light decomposition data, indexed by vertex:
// top[x]: first (shallowest) vertex of the chain containing x;
// pos[x]: 1-based position assigned to x by dfss(), used as its segment-tree index;
// dep[x]: depth of x (main() gives the root depth 1).
ll top[maxn],pos[maxn],dep[maxn];
// sz[x]: size of the subtree rooted at x;
// fa[x]: parent of x (main() makes the root its own parent);
// son[x]: child of x with the largest subtree, 0 if x has no child.
ll sz[maxn],fa[maxn],son[maxn];
// Running counter that dfss() increments to hand out pos[] values.
int id;
void dfs(int x)
{
sz[x]=1;
for(int i=head[x];i;i=ed[i].nxt)
{
int to=ed[i].to;
if(to==fa[x]) continue;
dep[to] = dep[x]+1;
fa[to] = x;
dfs(to);
sz[x] += sz[to];
if (sz[to]>sz[son[x]])
son[x]=to;
}
}
void dfss(int x,int p)
{
top[x] = p; pos[x] = ++id;
if(son[x]) dfss(son[x], p);
for(int i=head[x];i;i=ed[i].nxt)
{
int to = ed[i].to;
if(to == fa[x] || to == son[x]) continue;
dfss(to, to);
}
}
// Segment tree storage: for node k, num[k] is the sum and mxx[k] the maximum
// of the values in the range that node covers.
ll num[maxn<<2],mxx[maxn<<2];
// Recomputes node k's sum and maximum from its two children.
void pb(int k)
{
    const int left = k << 1;
    const int right = left | 1;
    num[k] = num[left] + num[right];
    mxx[k] = max(mxx[left], mxx[right]);
}
// Point assignment: sets the value at position `id` to `v`.
// k is the current node and [l, r] the range it covers; ancestors of the
// touched leaf are refreshed on the way back up.
void update(int k,int l,int r,int id,ll v)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        if (mid < id)
            update(k << 1 | 1, mid + 1, r, id, v);
        else
            update(k << 1, l, mid, id, v);
        pb(k);
        return;
    }
    // Leaf: a single element is both its own sum and its own maximum.
    num[k] = v;
    mxx[k] = v;
}
// Returns the sum of the values at positions [s, e].
// k is the current node and [l, r] the range it covers.
ll sum(int k,int l,int r,int s,int e)
{
    // Node range lies completely inside the query.
    if (l >= s && e >= r)
        return num[k];

    const int mid = (l + r) >> 1;
    const ll leftPart = (mid >= s) ? sum(k << 1, l, mid, s, e) : 0;
    const ll rightPart = (e > mid) ? sum(k << 1 | 1, mid + 1, r, s, e) : 0;
    return leftPart + rightPart;
}
// Returns the maximum of the values at positions [s, e].
// k is the current node and [l, r] the range it covers.
// -0x3f3f3f3f is the "nothing seen yet" value.
ll mxv(int k,int l,int r,int s,int e)
{
    // Node range lies completely inside the query.
    if (l >= s && e >= r)
        return mxx[k];

    const int mid = (l + r) >> 1;
    ll best = -0x3f3f3f3f;
    if (mid >= s)
    {
        const ll fromLeft = mxv(k << 1, l, mid, s, e);
        if (best < fromLeft) best = fromLeft;
    }
    if (mid < e)
    {
        const ll fromRight = mxv(k << 1 | 1, mid + 1, r, s, e);
        if (best < fromRight) best = fromRight;
    }
    return best;
}
// Returns the sum of the vertex weights on the tree path between x and y.
ll query_sum(ll x, ll y)
{
    ll total = 0;
    // While the endpoints are on different chains, take the one whose chain
    // head is deeper, add its chain segment, and jump to the parent of that head.
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        total += sum(1, 1, n, pos[top[x]], pos[x]);
    }

    // Same chain now: query from the shallower endpoint's position to the
    // deeper endpoint's position.
    const bool xIsShallower = dep[x] < dep[y];
    const ll upper = xIsShallower ? x : y;
    const ll lower = xIsShallower ? y : x;
    return total + sum(1, 1, n, pos[upper], pos[lower]);
}
// Returns the largest vertex weight on the tree path between x and y.
ll query_max(ll x, ll y)
{
    ll best = -0x3f3f3f3f;
    // While the endpoints are on different chains, take the one whose chain
    // head is deeper, fold in its chain segment, and jump to the parent of that head.
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        best = max(best, mxv(1, 1, n, pos[top[x]], pos[x]));
    }

    // Same chain now: query from the shallower endpoint's position to the
    // deeper endpoint's position.
    const bool xIsShallower = dep[x] < dep[y];
    const ll upper = xIsShallower ? x : y;
    const ll lower = xIsShallower ? y : x;
    return max(best, mxv(1, 1, n, pos[upper], pos[lower]));
}
int main()
{
memset(head, 0, sizeof head);
tot = 0;
memset(v, 0, sizeof v);
n=read();
for(int i=1;i<n;i++)
{
ll a=read(), b=read();
add(a, b);
}
for(int i=1;i<=n;i++) v[i] = read();
dep[1] = 1; fa[1] = 1;
dfs(1);
dfss(1, 1);
for(int i=1;i<=n;i++)
update(root,pos[i],v[i]);
m=read();
char st[10];
ll a,b;
for(int i=1;i<=m;i++)
{
cin>>st;
a=read(),b=read();
if(st[1] == 'M')
{
printf("%lld\n",query_max(a,b));
}
if(st[1] == 'S')
{
printf("%lld\n",query_sum(a, b));
}
if(st[1] == 'H')
{
update(root, pos[a], b);
}
}
return 0;
}