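// Solution notes: heavy-light decomposition (tree chain decomposition) template problem.
// Store each edge's weight on whichever of its two endpoints is deeper; when answering
// a path query, finally subtract the value stored on the lowest common ancestor.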
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
using namespace std;
// Segment-tree shorthands. They expand textually, so they are only usable inside
// functions whose parameters are named rt (node index) and L, R (range covered
// by that node). Every expression macro is fully parenthesized so that it keeps
// its value when it appears inside a larger expression (e.g. `x - rs`).
#define ls (2*rt)
#define rs (2*rt+1)
#define mid ((L+R)/2)
// Argument triples (child index, child range) for recursing into the left /
// right child. These are argument lists, not expressions, so they cannot be
// wrapped in parentheses.
#define lson ls,L,mid
#define rson rs,mid+1,R
typedef long long int ll;
// Capacity of the per-vertex and per-edge arrays (100005 entries).
const int mx = 1e5+5;
// One directed edge in the linked adjacency lists built by add().
struct node{
    int v;     // vertex this edge leads to
    int pos;   // number of the input edge this entry was created from
    int next;  // index in e[] of the next edge leaving the same vertex; 0 ends the list
};
// Edge pool. add() pre-increments its counter, so slot 0 is never handed out
// and can serve as the list terminator.
node e[mx<<2];
// ---- adjacency list --------------------------------------------------------
int head[mx];    // head[u]: index in e[] of the most recently added edge leaving u (0 = none)
int tot;         // number of e[] slots in use; reset by init(), advanced by add()

// ---- input edges -----------------------------------------------------------
int val[mx];     // val[k]: weight read for the k-th input edge
int pos[mx];     // pos[k]: endpoint of input edge k that dfs() visited as the child

// ---- rooted-tree data filled by dfs() ---------------------------------------
int fa[mx];      // fa[u]: parent of u (the root is passed itself as parent)
int dep[mx];     // dep[u]: depth of u, counted from the value given for the root
int siz[mx];     // siz[u]: number of vertices in the subtree of u
int son[mx];     // son[u]: child of u with the largest subtree, 0 if u has no child
int w[mx];       // w[u]: current weight of the edge between u and fa[u]

// ---- chain decomposition filled by DFS() ------------------------------------
int top[mx];     // top[u]: first vertex of the chain that contains u
int id[mx];      // id[u]: position of u in the segment tree
int dfn;         // counter used to hand out id[] values; reset by init()

// ---- segment tree ----------------------------------------------------------
ll sum[mx<<2];   // sum[rt]: total of the leaves below segment-tree node rt

// ---- problem size ----------------------------------------------------------
int n;           // number of vertices; also the right end of the segment-tree range [1, n]
int m;           // number of operations still to be processed
// Resets the state that must be clean before a new test case is read:
// the DFS-order counter, the edge counter, every adjacency-list head and
// every segment-tree node.
void init(){
    dfn = 0;
    tot = 0;
    std::fill(sum, sum + sizeof(sum) / sizeof(sum[0]), 0);
    std::fill(head, head + sizeof(head) / sizeof(head[0]), 0);
}
// Point assignment in the segment tree.
//   rt   - index of the current node
//   L, R - inclusive range covered by node rt
//   p    - leaf position to overwrite (expected to lie in [L, R])
//   v    - new value for that leaf
// Every node on the path back to the root gets its sum recomputed.
void update(int rt,int L,int R,int p,int v){
    if(L==R){
        sum[rt] = v;
        return;
    }
    const int center = (L+R)/2;
    const int lc = 2*rt;
    const int rc = 2*rt+1;
    if(p<=center)
        update(lc,L,center,p,v);
    else
        update(rc,center+1,R,p,v);
    sum[rt] = sum[lc]+sum[rc];
}
// Pushes a directed edge u -> v onto the front of u's adjacency list.
//   pos - number of the input edge this entry belongs to
//         (inside this function the parameter hides the global pos[] array)
void add(int u,int v,int pos){
    const int slot = ++tot;   // first slot handed out is 1; 0 stays the terminator
    e[slot].v = v;
    e[slot].pos = pos;
    e[slot].next = head[u];
    head[u] = slot;
}
// Range-sum query on the segment tree.
//   rt   - index of the current node
//   L, R - inclusive range covered by node rt
//   l, r - inclusive query range (callers pass l <= r inside [L, R])
// Returns the sum of the leaves at positions l..r.
ll query(int rt,int L,int R,int l,int r){
    // Node range lies completely inside the query range.
    if(l<=L && R<=r)
        return sum[rt];
    const int center = (L+R)/2;
    const int lc = 2*rt;
    const int rc = 2*rt+1;
    // Query is entirely in the right half.
    if(l>center)
        return query(rc,center+1,R,l,r);
    // Query is entirely in the left half.
    if(r<=center)
        return query(lc,L,center,l,r);
    // Query straddles the midpoint: split it and add both parts.
    const ll leftPart = query(lc,L,center,l,center);
    const ll rightPart = query(rc,center+1,R,center+1,r);
    return leftPart+rightPart;
}
// Returns the sum of the edge weights on the tree path between a and b.
// Each vertex carries the weight of the edge to its parent, so the path sum is
// the sum over all vertices on the path minus the value stored on the topmost
// vertex of the path.
ll solve(int a,int b){
    ll total = 0;
    for(;;){
        int chainA = top[a];
        const int chainB = top[b];
        if(chainA==chainB)
            break;
        // Always lift the endpoint whose chain starts deeper.
        if(dep[chainA]<dep[chainB]){
            std::swap(a,b);
            chainA = chainB;
        }
        total += query(1,1,n,id[chainA],id[a]);
        a = fa[chainA];   // jump to the chain above
    }
    // Both endpoints now share a chain; order them by segment-tree position.
    const bool aFirst = !(id[a]>id[b]);
    const int upper = aFirst ? a : b;
    const int lower = aFirst ? b : a;
    // upper is the topmost vertex of the path: its stored edge is not on the path.
    total += query(1,1,n,id[upper],id[lower])-w[upper];
    return total;
}
// First traversal: roots the tree and gathers per-vertex data.
//   u   - vertex being visited
//   pre - vertex we arrived from (recorded as fa[u])
//   d   - depth to record for u
// Fills fa, dep, siz and son for the subtree of u. For every edge followed
// downwards it records the child endpoint in pos[] and copies the edge weight
// onto that child in w[].
// Note: recursion depth equals the height of the tree.
void dfs(int u,int pre,int d){
    fa[u] = pre;
    dep[u] = d;
    siz[u] = 1;
    son[u] = 0;   // 0 = no heavy child yet; siz[0] is never written and stays 0
    for(int i = head[u]; i != 0; i = e[i].next){
        const int child = e[i].v;
        if(child==pre)
            continue;   // do not walk back up
        const int edgeNo = e[i].pos;
        pos[edgeNo] = child;
        w[child] = val[edgeNo];
        dfs(child,u,d+1);
        siz[u] += siz[child];
        if(siz[child]>siz[son[u]])
            son[u] = child;
    }
}
// Second traversal: splits the tree into chains and loads the segment tree.
//   u  - vertex being visited
//   tp - first vertex of the chain u belongs to
// Assigns id[u] in visiting order, writes w[u] to that segment-tree position,
// continues the current chain through son[u] first, and then starts a new
// chain at every other child.
void DFS(int u,int tp){
    id[u] = ++dfn;
    top[u] = tp;
    update(1,1,n,id[u],w[u]);
    const int heavy = son[u];
    if(heavy)
        DFS(heavy,tp);
    for(int i = head[u]; i != 0; i = e[i].next){
        const int child = e[i].v;
        if(child==fa[u] || child==heavy)
            continue;
        DFS(child,child);   // a light child is the top of its own chain
    }
}
int main(){
while(scanf("%d%d",&n,&m)!=EOF){
init();
for(int i = 2; i <= n; i++){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
val[i-1] = w;
add(u,v,i-1);
add(v,u,i-1);
}
dfs(1,1,1);
DFS(1,1);
while(m--){
int ca,a,b;
scanf("%d%d%d",&ca,&a,&b);
if(ca==0){
int v = pos[a];
w[v] = b;
update(1,1,n,id[v],b);
}
else
printf("%I64d\n",solve(a,b));
}
}
return 0;
}