这题的区间维护比较麻烦,顺便复习了一下区间合并
维护区间间隔色段数,跨链时更新一下上一条链顶的颜色,去重
#include<bits/stdc++.h>
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<stdio.h>
#include<algorithm>
#include<queue>
#include<string.h>
#include<iostream>
#include<math.h>
#include<set>
#include<map>
#include<vector>
#include<iomanip>
using namespace std;
#define ll long long
#define pb push_back
#define FOR(a) for(int i=1;i<=a;i++)
const int inf=0x3f3f3f3f;
const int maxn=1e5+9;
int arr[maxn];
int n,q;
struct EDGE{
int v;int next;
}G[maxn<<1];
int head[maxn],tot;
void addedge(int u,int v){
++tot;G[tot].v=v;G[tot].next=head[u];head[u]=tot;
}
int top[maxn];
int pre[maxn];
int dep[maxn];
int num[maxn];
int p[maxn]; //v的对应位置
int out[maxn]; //退出时间戳
int fp[maxn]; //访问序列
int son[maxn]; //重儿子
int pos;
void init(){
memset(head,-1,sizeof head);tot=0;
memset(son,-1,sizeof son);
}
void dfs1(int u,int fa,int d){
dep[u]=d;
pre[u]=fa;
num[u]=1;
for(int i=head[u];~i;i=G[i].next){
int v=G[i].v;
if(v==fa)continue;
dfs1(v,u,d+1);
num[u]+=num[v];
if(son[u]==-1||num[v]>num[son[u]])son[u]=v;
}
}
void getpos(int u,int sp){
top[u]=sp;
p[u]=out[u]=++pos;
fp[p[u]]=u;
if(son[u]==-1)return;
getpos(son[u],sp);
for(int i=head[u];~i;i=G[i].next){
int v=G[i].v;
if(v!=son[u]&&v!=pre[u])getpos(v,v);
}
out[u]=pos;
}
struct NODE{
int lcol,rcol,sum,lazy;
}ST[maxn<<2];
void pushup(int rt){
ST[rt].sum=ST[rt<<1].sum+ST[rt<<1|1].sum;
if(ST[rt<<1].rcol==ST[rt<<1|1].lcol)ST[rt].sum--;
ST[rt].lcol=ST[rt<<1].lcol;ST[rt].rcol=ST[rt<<1|1].rcol;
}
void pushdown(int rt){
if(!ST[rt].lazy)return;
ST[rt<<1].lcol=ST[rt<<1|1].lcol=ST[rt<<1].rcol=ST[rt<<1|1].rcol=
ST[rt].lcol;
ST[rt<<1].sum=ST[rt<<1|1].sum=1;
ST[rt<<1].lazy=ST[rt<<1|1].lazy=1;
ST[rt].lazy=0;
}
void build(int l,int r,int rt){
if(l==r){ST[rt].sum=1;ST[rt].lcol=ST[rt].rcol=arr[fp[l]];return;}
int m=l+r>>1;build(l,m,rt<<1);build(m+1,r,rt<<1|1);pushup(rt);
}
void update(int a,int b,int c,int l,int r,int rt){
if(a<=l&&b>=r){
ST[rt].sum=1;
ST[rt].lcol=ST[rt].rcol=c;
ST[rt].lazy=1;
return;
}
pushdown(rt);
int m=l+r>>1;
if(a<=m)update(a,b,c,l,m,rt<<1);
if(b>m)update(a,b,c,m+1,r,rt<<1|1);
pushup(rt);
}
int L,R;
int query(int a,int b,int l,int r,int rt){
if(a==l)L=ST[rt].lcol;
if(b==r)R=ST[rt].rcol;
if(a<=l&&b>=r)return ST[rt].sum;
pushdown(rt);
int m=l+r>>1;
int ans=0;
if(b<=m)return query(a,b,l,m,rt<<1);
else if(a>m)return query(a,b,m+1,r,rt<<1|1);
if(ST[rt<<1].rcol==ST[rt<<1|1].lcol)ans--;
ans+=query(a,b,l,m,rt<<1);ans+=query(a,b,m+1,r,rt<<1|1);
return ans;
}
void solve1(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(p[top[x]],p[x],z,1,n,1);
x=pre[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
update(p[y],p[x],z,1,n,1);
}
void solve2(int x,int y){
int ans=0,ans1=-1,ans2=-1;//上次链的左端颜色
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]){swap(x,y);swap(ans1,ans2);}
ans+=query(p[top[x]],p[x],1,n,1);
//cout<<L<<" "<<R<<endl;
if(R==ans1)ans--;
ans1=L;x=pre[top[x]];
}
if(dep[x]<dep[y]){swap(x,y);swap(ans1,ans2);}
ans+=query(p[y],p[x],1,n,1);
if(R==ans1)ans--;if(L==ans2)ans--;
printf("%d\n",ans);
}
char op[5];
int main(){
scanf("%d%d",&n,&q);
init();
for(int i=1;i<=n;i++){scanf("%d",&arr[i]);}
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);addedge(x,y);addedge(y,x);
}
dfs1(1,1,0);getpos(1,1);
int x,y,z;
build(1,n,1);
while(q--){
scanf("%s",op);
if(op[0]=='Q'){
scanf("%d%d",&x,&y);
solve2(x,y);
}else{
scanf("%d%d%d",&x,&y,&z);
solve1(x,y,z);
}
}
}