原题传送门
其实也是一道树剖裸题啊
四个操作
- Change:单点修改
- Cover:区间覆盖
- Add:区间加
- Max:区间最值
是不是很裸啊,不过此题有两个特点,也是需要特别注意的地方
- 维护边权,因为线段树维护的是点,所以我们采用以点代边的方法,用一条边两端点中深度更大的那一个点代表这条边,最终跳链跳到两点处于同一条链上时,深度更小的那个点需要id+1
- 区间覆盖操作,要处理覆盖与加法之间的平衡比较麻烦,要全局考虑。这里线段树维护三个值,sum表示区间的最大值,add表示加法的lazytag,tag表示覆盖的lazytag
覆盖时,需要把add置为0
pushdown时,先下传tag(同时把子节点的add置为0),再下传add
码量其实也不是很大,156行足矣~
Code:
#include <bits/stdc++.h>
#define maxn 100010
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
struct Edge{
int to, next;
}edge[maxn << 1];
struct Seg{
int l, r, sum, tag, add;
}seg[maxn << 2];
struct Line{
int x, y, z;
}line[maxn];
int head[maxn], num, d[maxn], fa[maxn], size[maxn], son[maxn], id[maxn], cnt, top[maxn], w[maxn], wt[maxn], n;
inline int read(){
int s = 0, w = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
return s * w;
}
void add_edge(int x, int y){ edge[++num].to = y; edge[num].next = head[x]; head[x] = num; }
void dfs(int u){
size[u] = 1, son[u] = -1;
for (int i = head[u]; i; i = edge[i].next){
int v = edge[i].to;
if (v != fa[u]){
fa[v] = u, d[v] = d[u] + 1;
dfs(v);
size[u] += size[v];
if (son[u] == -1 || son[u] != -1 && size[son[u]] < size[v]) son[u] = v;
}
}
}
void dfs(int u, int x){
id[u] = ++cnt, top[u] = x, wt[cnt] = w[u];
if (son[u] == -1) return;
dfs(son[u], x);
for (int i = head[u]; i; i = edge[i].next){
int v = edge[i].to;
if (v != fa[u] && v != son[u]) dfs(v, v);
}
}
void pushup(int rt){ seg[rt].sum = max(seg[ls].sum, seg[rs].sum); }
void pushdown(int rt){
if (seg[rt].tag != -1){
seg[ls].sum = seg[rs].sum = seg[rt].tag;
seg[ls].add = seg[rs].add = 0;
seg[ls].tag = seg[rs].tag = seg[rt].tag;
}
if (seg[rt].add){
seg[ls].sum += seg[rt].add;
seg[rs].sum += seg[rt].add;
seg[ls].add += seg[rt].add;
seg[rs].add += seg[rt].add;
}
seg[rt].add = 0, seg[rt].tag = -1;
}
void build(int rt, int l, int r){
seg[rt].l = l, seg[rt].r = r, seg[rt].tag = -1;
if (l == r){
seg[rt].sum = wt[l]; return;
}
int mid = (l + r) >> 1;
build(ls, l, mid); build(rs, mid + 1, r);
pushup(rt);
}
void update(int rt, int l, int r, int k){
pushdown(rt);
if (seg[rt].l > r || seg[rt].r < l) return;
if (seg[rt].l >= l && seg[rt].r <= r){
seg[rt].sum += k, seg[rt].add += k; return;
}
update(ls, l, r, k); update(rs, l, r, k);
pushup(rt);
}
void change(int rt, int l, int r, int k){
pushdown(rt);
if (seg[rt].l > r || seg[rt].r < l) return;
if (seg[rt].l >= l && seg[rt].r <= r){
seg[rt].sum = k, seg[rt].add = 0, seg[rt].tag = k; return;
}
change(ls, l, r, k); change(rs, l, r, k);
pushup(rt);
}
int query(int rt, int l, int r){
if (seg[rt].l > r || seg[rt].r < l) return 0;
if (seg[rt].l >= l && seg[rt].r <= r) return seg[rt].sum;
pushdown(rt);
return max(query(ls, l, r), query(rs, l, r));
}
int main(){
n = read();
for (int i = 1; i < n; ++i){
int x = read(), y = read(), z = read();
add_edge(x, y); add_edge(y, x);
line[i].x = x; line[i].y = y; line[i].z = z;
}
dfs(1);
for (int i = 1; i < n; ++i) w[fa[line[i].x] == line[i].y ? line[i].x : line[i].y] = line[i].z;
dfs(1, 1);
build(1, 1, n);
while (1){
char c = getchar(); for (; c != 'a' && c != 'o' && c != 'd' && c != 'h' && c != 't'; c = getchar());
if (c == 't') break;
if (c == 'h'){
int x = read(), y = read();
int z = fa[line[x].x] == line[x].y ? line[x].x : line[x].y;
change(1, id[z], id[z], y);
}
if (c == 'o'){
int x = read(), y = read(), z = read();
while (top[x] != top[y]){
if (d[top[x]] < d[top[y]]) swap(x, y);
change(1, id[top[x]], id[x], z);
x = fa[top[x]];
}
if (d[x] < d[y]) swap(x, y);
if (id[y] + 1 <= id[x]) change(1, id[y] + 1, id[x], z);
}
if (c == 'd'){
int x = read(), y = read(), z = read();
while (top[x] != top[y]){
if (d[top[x]] < d[top[y]]) swap(x, y);
update(1, id[top[x]], id[x], z);
x = fa[top[x]];
}
if (d[x] < d[y]) swap(x, y);
if (id[y] + 1 <= id[x]) update(1, id[y] + 1, id[x], z);
}
if (c == 'a'){
int x = read(), y = read(), ans = 0;
while (top[x] != top[y]){
if (d[top[x]] < d[top[y]]) swap(x, y);
ans = max(ans, query(1, id[top[x]], id[x]));
x = fa[top[x]];
}
if (d[x] < d[y]) swap(x, y);
if (id[y] + 1 <= id[x]) ans = max(ans, query(1, id[y] + 1, id[x]));
printf("%d\n", ans);
}
}
return 0;
}