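// The problem statement itself is straightforward.
// The hard part was maintaining the segment tree (per-range segment counts under range color assignment); it took a long time to get right.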
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
using namespace std;
// Capacity for vertices (with a little slack); arrays below are sized from it.
const int M = 100000 + 10;

// Forward-star adjacency list: head[x] is the index of x's most recently added
// edge (-1 when x has none), cnt is the number of edge slots used so far.
int head[M], cnt;

// One directed edge stored in the forward-star list.
struct Edge{
    int u, v, next;
    // Fills this slot with the directed edge from -> to and links it at the
    // front of from's list. The slot is recorded as index cnt, so this is
    // meant to be invoked on edge[cnt]; cnt is advanced afterwards.
    void set(int from, int to){
        u = from;
        v = to;
        next = head[from];
        head[from] = cnt;
        ++cnt;
    }
}edge[M << 1];

// Segment-tree node covering positions [l, r]:
//   n  - number of maximal same-color segments inside the range
//   ll - color at position l, rr - color at position r
//   f  - lazy flag: the whole range holds one color not yet pushed to children
struct Node{
    int l, r, n, ll, rr, f;
}node[M << 4];

// Heavy-light decomposition data, indexed by vertex:
//   pre - parent, son - heavy child (0 = none), siz - subtree size,
//   top - head of the vertex's heavy chain, dep - depth,
//   pos - the vertex's position in the segment tree.
// tot is the last position handed out.
int pre[M], son[M], siz[M], top[M], dep[M], pos[M], tot;
// s - initial color per vertex, n - vertex count, m - operation count.
int s[M], n, m;
void build(int l, int r, int p){
node[p].l = l, node[p].r = r;
node[p].f = 0, node[p].n = 0;
if(l == r)return ;
int mid = (l+r) >> 1;
build(l, mid, p << 1);
build(mid+1, r, p << 1|1);
}
// Writes color c into the leaf for position k, descending from node p.
// The leaf becomes a single segment. Ancestors are NOT refreshed here;
// update2 recomputes them afterwards.
void update1(int k, int c, int p){
    while(node[p].l != node[p].r){
        const int mid = (node[p].l + node[p].r) >> 1;
        // Go right when k lies past the midpoint, otherwise left.
        p = (p << 1) | (k > mid ? 1 : 0);
    }
    node[p].ll = c;
    node[p].rr = c;
    node[p].n = 1;
}
void update2(int l, int r, int p){
if(l == r)return ;
int mid = (l+r) >> 1;
update2(l, mid, p << 1);
update2(mid+1, r, p << 1|1);
node[p].ll = node[p << 1].ll;
node[p].rr = node[p << 1|1].rr;
node[p].n = node[p << 1].n+node[p << 1|1].n-(node[p << 1].rr == node[p << 1|1].ll);
}
// First heavy-light pass over the subtree of u. It fills in pre (parent),
// dep (depth) and siz (subtree size), and chooses son[u] as the first child
// with the strictly largest subtree (0 when u has no children).
// f is u's parent (0 for the root) and d is u's depth.
void dfs_1(int u, int f, int d){
    pre[u] = f;
    dep[u] = d;
    siz[u] = 1;
    son[u] = 0;
    for(int e = head[u]; e != -1; e = edge[e].next){
        const int child = edge[e].v;
        if(child == f) continue;
        dfs_1(child, u, d + 1);
        // siz[0] is never assigned and stays 0, so any real child beats "no heavy child yet".
        if(siz[child] > siz[son[u]]) son[u] = child;
        siz[u] += siz[child];
    }
}
// Second heavy-light pass. It hands out consecutive positions (pos) so that
// every heavy chain occupies a contiguous position range, and records the
// chain head (top) of every vertex. tp is the head of the chain u is on.
void dfs_2(int u, int tp){
    top[u] = tp;
    pos[u] = ++tot;
    // The heavy child continues the current chain, so it is numbered next.
    if(son[u] != 0) dfs_2(son[u], tp);
    // Each light child opens a new chain headed by itself.
    for(int e = head[u]; e != -1; e = edge[e].next){
        const int child = edge[e].v;
        if(child == pre[u] || child == son[u]) continue;
        dfs_2(child, child);
    }
}
// Hands node p's pending single-color assignment to both of its children and
// clears p's lazy flag. Each child becomes a single lazily tagged segment
// whose boundary colors are copied from p.
void pushdown(int p){
    node[p].f = 0;
    const int firstChild = p << 1;
    for(int child = firstChild; child <= firstChild + 1; ++child){
        node[child].f = 1;
        node[child].ll = node[p].ll;
        node[child].rr = node[p].rr;
        node[child].n = 1;
    }
}
int query(int l, int r, int p){
if(node[p].l == l && node[p].r == r){
return node[p].n;
}
if(node[p].f)pushdown(p);
int mid = (node[p].l+node[p].r) >> 1, ret = 0;
if(l <= mid){
ret += query(l, min(r, mid), p << 1);
}
if(r>mid){
ret += query(max(l, mid+1), r, p << 1|1);
}
return ret-(l <= mid && r>mid && node[p << 1].rr == node[p << 1|1].ll);
}
// Returns the color at position k, walking down from node p. The first node
// on the way down that carries a pending assignment (f set) already holds the
// answer in ll, so the walk stops there without pushing anything down.
int getcol(int k, int p){
    while(!node[p].f && node[p].l != node[p].r){
        const int mid = (node[p].l + node[p].r) >> 1;
        p = (p << 1) | (k > mid ? 1 : 0);
    }
    return node[p].ll;
}
int Qsum(int u, int v){
int f1 = top[u], f2 = top[v];
int ru = -1, rv = -1, sum = 0, tmp;
while(f1 != f2){
if(dep[f1]>dep[f2]){
sum += query(pos[f1], pos[u], 1);
sum -= (getcol(pos[u], 1) == ru);
ru = getcol(pos[f1], 1);
u = pre[f1], f1 = top[u];
}
else{
sum += query(pos[f2], pos[v], 1);
sum -= (getcol(pos[v], 1) == rv);
rv = getcol(pos[f2], 1);
v = pre[f2], f2 = top[v];
}
}
if(dep[u]<dep[v]){
tmp = query(pos[u], pos[v], 1);
}
else{
tmp = query(pos[v], pos[u], 1);
}
sum += tmp;
sum -= (ru == getcol(pos[u], 1));
sum -= (rv == getcol(pos[v], 1));
return sum;
}
void update(int l, int r, int p, int c){
if(node[p].l == l && node[p].r == r){
node[p].ll = c, node[p].rr = c;
node[p].n = 1, node[p].f = 1;
return ;
}
if(node[p].f)pushdown(p);
int mid = (node[p].l+node[p].r) >> 1;
if(l <= mid){
update(l, min(r, mid), p << 1, c);
}
if(r>mid){
update(max(l, mid+1), r, p << 1|1, c);
}
node[p].ll = node[p << 1].ll;
node[p].rr = node[p << 1|1].rr;
node[p].n = node[p << 1].n+node[p << 1|1].n-(node[p << 1].rr == node[p << 1|1].ll);
}
// Paints every vertex on the tree path u..v with color c. Each step paints
// the heavy-chain piece of whichever endpoint has the deeper chain head
// (ties go to u) and then jumps above that chain head. The final piece lies
// on a single chain.
void change(int u, int v, int c){
    for(; top[u] != top[v]; u = pre[top[u]]){
        if(dep[top[u]] < dep[top[v]]) swap(u, v);
        update(pos[top[u]], pos[u], 1, c);
    }
    if(dep[u] < dep[v]) swap(u, v);
    update(pos[v], pos[u], 1, c);
}
// Processes test cases until input runs out. Each case is read in this order:
//   n m, then n initial vertex colors, then n-1 undirected edges u v,
//   then m operations:
//     "Q u v"       -> print the number of color segments on path u..v
//     any other op  -> read u v c and paint path u..v with color c
// If an operation cannot be read completely, the rest of that case is abandoned.
int main(){
    while(cin >> n >> m){
        cnt = 0, tot = 0;
        memset(head, -1, sizeof(head));
        for(int i = 1; i <= n; i++){
            scanf("%d", &s[i]);
        }
        build(1, n, 1);
        for(int i = 1, u, v; i<n; i++){
            scanf("%d%d", &u, &v);
            edge[cnt].set(u, v);
            edge[cnt].set(v, u);
        }
        // Root at an arbitrary vertex; path results do not depend on the root.
        dfs_1((n+1)/2, 0, 1);
        dfs_2((n+1)/2, (n+1)/2);
        // Place each vertex's initial color at its HLD position, then build inner nodes.
        for(int i = 1; i <= n; i++){
            update1(pos[i], s[i], 1);
        }
        update2(1, n, 1);
        while(m--){
            char op[5];
            // Width 4 leaves room for the terminating NUL in op[5]; an
            // unbounded %s would overflow on a longer token.
            if(scanf("%4s", op) != 1)break;
            if(op[0] == 'Q'){
                int u, v;
                // Never run a query on uninitialized endpoints.
                if(scanf("%d%d", &u, &v) != 2)break;
                printf("%d\n", Qsum(u, v));
            }
            else{
                int u, v, c;
                if(scanf("%d%d%d", &u, &v, &c) != 3)break;
                change(u, v, c);
            }
        }
    }
    return 0;
}