题目简述:q次询问,求树上俩个点路径上,右边点权减左边点权的最大值。
思路:树剖+线段树
树剖部分,简单树剖即可,无需其他操作
int sum[N], fa[N], dep[N], son[N], top[N];
int pre[N], new_point[N];
int n;
void dfs_1(int father, int cur) {
sum[cur] = 1;
for (auto v: f[cur]) {
if (v == father) continue;
fa[v] = cur;
dep[v] = dep[cur] + 1;
dfs_1(cur, v);
sum[cur] += sum[v];
if (sum[v] > sum[son[cur]]) son[cur] = v;
}
}
int tot;//时间戳
void dfs_2(int fcur, int cur) {
top[cur] = fcur;
tot++;
pre[tot] = cur;//存旧节点编号,方便后面build函数初始化
new_point[cur] = tot;//存新节点编号,方便后面进行线段树的修改和查询
if (son[cur]) dfs_2(fcur, son[cur]);
for (auto v: f[cur]) {
if (v == fa[cur] || v == son[cur]) continue;
dfs_2(v, v);
}
}
线段树部分,首先确定需要维护几个信息:
第一个是左右端点
第二个就是区间最大和区间最小,分别是maxx,minn(用来维护下面第三个和第四个信息)
第三个根据题目定义一个lmax,表示右边节点的点权-左边节点点权的最大值
第四个则相反,定义应该rmax,表示左边节点的点权-右边节点点权的最大值(这俩个标记弱弱我一开始没想到捏,此题第一篇大佬的题解让我幡然醒悟)
然后就要想办法进行分治更新,想一下怎么从左右节点的信息更新我父节点的信息,最大和最小值弱弱我相信各位大佬都会,俺说一下第三信息怎么更新,第四个可以依次类推:
公式如上,俩个儿子没合并之前,父亲的lmax必然会等于左右儿子的lmax中的最大值,考虑合并,合并后可能对父亲lmax产生影响的,即可能可以更新父亲lmax的情况只有右儿子的区间最大值减去左儿子的区间最小值,所以直接对这三者取max即可
第五个是区间加的懒标记,表示这个区间累加的值,这样询问和修改的时候就不用每次访问每一个区间。
最后我们还要再想一下怎么处理题目里给的查询和修改操作
修改,类比树剖求lca,在俩个点跳的同时进行线段树的motify(修改)操作
void cli(int l, int r, int v) {
while (top[l] != top[r]) {
if (dep[top[l]] < dep[top[r]]) swap(l, r);
motify(1, new_point[top[l]], new_point[l], v);
l = fa[top[l]];
}
if (dep[l] > dep[r]) swap(l, r);
motify(1, new_point[l], new_point[r], v);
}
//l,r为询问给的俩个节点
//new_point数组是前面提到过的存dfs序的数组
//motify操作,涉及懒标记的下传
void pushdown(int u) {
if (xd[u].lazy) {
xd[u << 1].maxx += xd[u].lazy;
xd[u << 1].mmin += xd[u].lazy;
xd[u << 1].lazy += xd[u].lazy;
xd[u << 1 | 1].maxx += xd[u].lazy;
xd[u << 1 | 1].mmin += xd[u].lazy;
xd[u << 1 | 1].lazy += xd[u].lazy;
xd[u].lazy = 0;
}
}
void motify(int u, int l, int r, int c) {
if (xd[u].l >= l && xd[u].r <= r) {
xd[u].maxx += c;
xd[u].mmin += c;
xd[u].lazy += c;
} else {
pushdown(u);
int mid = xd[u].l + xd[u].r >> 1;
if (l <= mid) motify(u << 1, l, r, c);
if (r > mid) motify(u << 1 | 1, l, r, c);
pushup(u);
}
}
查询,类比修改操作,这里边查询边更新答案
int cli2(int l, int r) {
xds l_1, r_1;//俩个答案线段
xds temp;
l_1.lmax = l_1.rmax = 0;
l_1.maxx = -1e18;
l_1.mmin = 1e18;
r_1 = l_1;
while (top[l] != top[r]) {
if (dep[top[l]] >= dep[top[r]]) {
temp = query(1, new_point[top[l]], new_point[l]);
l = fa[top[l]];
l_1.lmax = max(max(temp.lmax, l_1.lmax), l_1.maxx - temp.mmin);
l_1.rmax = max(max(temp.rmax, l_1.rmax), temp.maxx - l_1.mmin);
l_1.maxx = max(l_1.maxx, temp.maxx);
l_1.mmin = min(l_1.mmin, temp.mmin);
} else {
temp = query(1, new_point[top[r]], new_point[r]);
r = fa[top[r]];
r_1.rmax = max(max(temp.rmax, r_1.rmax), temp.maxx - r_1.mmin);
r_1.lmax = max(max(temp.lmax, r_1.lmax), r_1.maxx - temp.mmin);
r_1.maxx = max(r_1.maxx, temp.maxx);
r_1.mmin = min(r_1.mmin, temp.mmin);
}
}
if (dep[l] < dep[r]) {
temp = query(1, new_point[l], new_point[r]);
r_1.rmax = max(max(temp.rmax, r_1.rmax), temp.maxx - r_1.mmin);
r_1.lmax = max(max(temp.lmax, r_1.lmax), r_1.maxx - temp.mmin);
r_1.maxx = max(r_1.maxx, temp.maxx);
r_1.mmin = min(r_1.mmin, temp.mmin);
} else {
temp = query(1, new_point[r], new_point[l]);
l_1.lmax = max(max(temp.lmax, l_1.lmax), l_1.maxx - temp.mmin);
l_1.rmax = max(max(temp.rmax, l_1.rmax), temp.maxx - l_1.mmin);
l_1.maxx = max(l_1.maxx, temp.maxx);
l_1.mmin = min(l_1.mmin, temp.mmin);//注意这里要先更新lmax和rmax再更新maxx和minn
//因为合并之前这俩个线段可能没有交集
}
return max(max(l_1.rmax, r_1.lmax), r_1.maxx - l_1.mmin);//最后返回一个最优值
}
//线段树query部分
xds query(int u, int l, int r) {
if (xd[u].l >= l && xd[u].r <= r) return xd[u];
else {
pushdown(u);
xds res;
int mid = xd[u].l + xd[u].r >> 1;
if (r <= mid) res = query(u << 1, l, r);
else if (l > mid) res = query(u << 1 | 1, l, r);
else {
xds l_1, r_1;
l_1 = query(u << 1, l, r);
r_1 = query(u << 1 | 1, l, r);
pushup(res, l_1, r_1);
// return res;
}
return res;
}
}
//pushup操作,返回最优答案,也就是不断更新父节点的信息操作
void pushup(xds &u, xds &l, xds &r) {
u.maxx = max(l.maxx, r.maxx);
u.mmin = min(l.mmin, r.mmin);
u.lmax = max(max(l.lmax, r.lmax), r.maxx - l.mmin);
u.rmax = max(max(l.rmax, r.rmax), l.maxx - r.mmin);
}
总结:弱弱我还得练
ps:完整代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 5e4 + 10;
struct xds {
int l, r;
int maxx;
int mmin;
int lazy;
int lmax;
int rmax;
} xd[4 * N];
vector<int> f[N];
int val[N];
int sum[N], fa[N], dep[N], son[N], top[N];
int pre[N], new_point[N];
int n;
void dfs_1(int father, int cur) {
sum[cur] = 1;
for (auto v: f[cur]) {
if (v == father) continue;
fa[v] = cur;
dep[v] = dep[cur] + 1;
dfs_1(cur, v);
sum[cur] += sum[v];
if (sum[v] > sum[son[cur]]) son[cur] = v;
}
}
int tot;
void dfs_2(int fcur, int cur) {
top[cur] = fcur;
tot++;
pre[tot] = cur;
new_point[cur] = tot;
if (son[cur]) dfs_2(fcur, son[cur]);
for (auto v: f[cur]) {
if (v == fa[cur] || v == son[cur]) continue;
dfs_2(v, v);
}
}
void pushup(int u) {
xd[u].maxx = max(xd[u << 1].maxx, xd[u << 1 | 1].maxx);
xd[u].mmin = min(xd[u << 1].mmin, xd[u << 1 | 1].mmin);
xd[u].lmax = max(max(xd[u << 1].lmax, xd[u << 1 | 1].lmax), xd[u << 1 | 1].maxx - xd[u << 1].mmin);
xd[u].rmax = max(max(xd[u << 1].rmax, xd[u << 1 | 1].rmax), xd[u << 1].maxx - xd[u << 1 | 1].mmin);
}
void build(int u, int l, int r) {
if (l == r) xd[u] = {l, r, val[pre[l]], val[pre[l]], 0, 0, 0};
else {
xd[u] = {l, r, 0, 0, 0, 0, 0};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void pushdown(int u) {
if (xd[u].lazy) {
xd[u << 1].maxx += xd[u].lazy;
xd[u << 1].mmin += xd[u].lazy;
xd[u << 1].lazy += xd[u].lazy;
xd[u << 1 | 1].maxx += xd[u].lazy;
xd[u << 1 | 1].mmin += xd[u].lazy;
xd[u << 1 | 1].lazy += xd[u].lazy;
xd[u].lazy = 0;
}
}
void motify(int u, int l, int r, int c) {
if (xd[u].l >= l && xd[u].r <= r) {
xd[u].maxx += c;
xd[u].mmin += c;
xd[u].lazy += c;
} else {
pushdown(u);
int mid = xd[u].l + xd[u].r >> 1;
if (l <= mid) motify(u << 1, l, r, c);
if (r > mid) motify(u << 1 | 1, l, r, c);
pushup(u);
}
}
void pushup(xds &u, xds &l, xds &r) {
u.maxx = max(l.maxx, r.maxx);
u.mmin = min(l.mmin, r.mmin);
u.lmax = max(max(l.lmax, r.lmax), r.maxx - l.mmin);
u.rmax = max(max(l.rmax, r.rmax), l.maxx - r.mmin);
}
xds query(int u, int l, int r) {
if (xd[u].l >= l && xd[u].r <= r) return xd[u];
else {
pushdown(u);
xds res;
int mid = xd[u].l + xd[u].r >> 1;
if (r <= mid) res = query(u << 1, l, r);
else if (l > mid) res = query(u << 1 | 1, l, r);
else {
xds l_1, r_1;
l_1 = query(u << 1, l, r);
r_1 = query(u << 1 | 1, l, r);
pushup(res, l_1, r_1);
// return res;
}
return res;
}
}
void cli(int l, int r, int v) {
while (top[l] != top[r]) {
if (dep[top[l]] < dep[top[r]]) swap(l, r);
motify(1, new_point[top[l]], new_point[l], v);
l = fa[top[l]];
}
if (dep[l] > dep[r]) swap(l, r);
motify(1, new_point[l], new_point[r], v);
}
int cli2(int l, int r) {
xds l_1, r_1;
xds temp;
l_1.lmax = l_1.rmax = 0;
l_1.maxx = -1e18;
l_1.mmin = 1e18;
r_1 = l_1;
while (top[l] != top[r]) {
if (dep[top[l]] >= dep[top[r]]) {
temp = query(1, new_point[top[l]], new_point[l]);
l = fa[top[l]];
l_1.lmax = max(max(temp.lmax, l_1.lmax), l_1.maxx - temp.mmin);
l_1.rmax = max(max(temp.rmax, l_1.rmax), temp.maxx - l_1.mmin);
l_1.maxx = max(l_1.maxx, temp.maxx);
l_1.mmin = min(l_1.mmin, temp.mmin);
} else {
temp = query(1, new_point[top[r]], new_point[r]);
r = fa[top[r]];
r_1.rmax = max(max(temp.rmax, r_1.rmax), temp.maxx - r_1.mmin);
r_1.lmax = max(max(temp.lmax, r_1.lmax), r_1.maxx - temp.mmin);
r_1.maxx = max(r_1.maxx, temp.maxx);
r_1.mmin = min(r_1.mmin, temp.mmin);
}
}
if (dep[l] < dep[r]) {
temp = query(1, new_point[l], new_point[r]);
r_1.rmax = max(max(temp.rmax, r_1.rmax), temp.maxx - r_1.mmin);
r_1.lmax = max(max(temp.lmax, r_1.lmax), r_1.maxx - temp.mmin);
r_1.maxx = max(r_1.maxx, temp.maxx);
r_1.mmin = min(r_1.mmin, temp.mmin);
} else {
temp = query(1, new_point[r], new_point[l]);
l_1.lmax = max(max(temp.lmax, l_1.lmax), l_1.maxx - temp.mmin);
l_1.rmax = max(max(temp.rmax, l_1.rmax), temp.maxx - l_1.mmin);
l_1.maxx = max(l_1.maxx, temp.maxx);
l_1.mmin = min(l_1.mmin, temp.mmin);
}
return max(max(l_1.rmax, r_1.lmax), r_1.maxx - l_1.mmin);
}
void solve() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> val[i];
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
f[u].push_back(v);
f[v].push_back(u);
}
dfs_1(0, 1);
dfs_2(1, 1);
build(1, 1, n);
int q;
cin >> q;
while (q--) {
int a, b, v;
cin >> a >> b >> v;
cli(a, b, v);
int ans = cli2(a, b);
if (ans > 0) cout << ans << '\n';
else cout << 0 << '\n';
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr);
int T = 1;
// cin >> T;
while (T--) { solve(); }
return 0;
}