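The segment tree needs two extra fields per node: lc (the color of the leftmost node in the interval) and rc (the color of the rightmost node in the interval). When two halves are merged, their segment counts are added, and one is subtracted if the left half's rc equals the right half's lc. Everything else should be plain heavy-light decomposition.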
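The merge rule can be tried out on its own, without the tree. The sketch below is a minimal illustration. `Seg`, `mergeSeg` and `summarize` are hypothetical names, not part of the solution. `Seg` bundles one node's sum/lc/rc, and `summarize` folds the same rule that pushup applies over a plain color array.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Summary of one interval, mirroring the post's sum/lc/rc arrays:
// cnt = number of maximal same-color segments, lc/rc = colors at its two ends.
struct Seg {
    int cnt;
    int lc;
    int rc;
};

// Combine an interval with the interval immediately to its right (the same
// rule as pushup). If the two touching ends share a color, the segment that
// crosses the boundary was counted once in each half, so subtract one.
Seg mergeSeg(const Seg& a, const Seg& b) {
    Seg res{a.cnt + b.cnt, a.lc, b.rc};
    if (a.rc == b.lc) res.cnt--;
    return res;
}

// Fold mergeSeg over a non-empty color array, one leaf (cnt = 1) per element.
Seg summarize(const std::vector<int>& colors) {
    assert(!colors.empty());
    Seg acc{1, colors[0], colors[0]};
    for (std::size_t i = 1; i < colors.size(); ++i)
        acc = mergeSeg(acc, Seg{1, colors[i], colors[i]});
    return acc;
}
```

For example, `summarize({1, 1, 2, 2, 1})` has `cnt == 3` (segments 1 1 | 2 2 | 1). The merge takes lc from the left operand and rc from the right one, so it is associative but not commutative. The order in which pieces are combined matters, and the path query depends on that too.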
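What got me stuck was the query. At the very end, once top[u]==top[v], you have to compare the colors at both ends of that last segment with the colors of the top nodes of the chains processed just before it, one on each side (I had those two chain tops mixed up).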
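The junction check can be tested apart from the segment tree. Below is a minimal sketch under a simplifying assumption: each chain interval is scanned in a plain array (O(length) per chain) instead of being queried in a segment tree, so only the tc1/tc2 bookkeeping is exercised. `PathRuns`, `scan` and the brute-force `naive` cross-check are hypothetical names for illustration. The sketch uses -1 rather than 0 as the "no chain yet" marker, which assumes colors are non-negative.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: HLD over a rooted tree (root 1) whose path query
// scans chain intervals in an array instead of using a segment tree.
struct PathRuns {
    std::vector<std::vector<int>> adj;
    std::vector<int> col;  // col[1..n]; col[0] is unused
    std::vector<int> fa, dep, siz, hson, top, id, atPos;
    int timer = 0;

    PathRuns(int n, const std::vector<int>& colors,
             const std::vector<std::pair<int, int>>& edges)
        : adj(n + 1), col(colors), fa(n + 1), dep(n + 1), siz(n + 1),
          hson(n + 1, 0), top(n + 1), id(n + 1), atPos(n + 1) {
        for (const auto& e : edges) {
            adj[e.first].push_back(e.second);
            adj[e.second].push_back(e.first);
        }
        dfs1(1, 0);
        dfs2(1, 1);
    }

    void dfs1(int u, int f) {  // parent, depth, subtree size, heavy child
        fa[u] = f; siz[u] = 1;
        for (int v : adj[u]) {
            if (v == f) continue;
            dep[v] = dep[u] + 1;
            dfs1(v, u);
            siz[u] += siz[v];
            if (!hson[u] || siz[v] > siz[hson[u]]) hson[u] = v;
        }
    }

    void dfs2(int u, int anc) {  // chain tops and positions, heavy child first
        top[u] = anc; id[u] = ++timer; atPos[timer] = u;
        if (!hson[u]) return;
        dfs2(hson[u], anc);
        for (int v : adj[u])
            if (v != fa[u] && v != hson[u]) dfs2(v, v);
    }

    // Segments in positions [l, r]; cl/cr get the colors at positions l and r.
    int scan(int l, int r, int& cl, int& cr) const {
        cl = col[atPos[l]]; cr = col[atPos[r]];
        int runs = 1;
        for (int p = l + 1; p <= r; ++p)
            if (col[atPos[p]] != col[atPos[p - 1]]) ++runs;
        return runs;
    }

    // Mirrors the post's Query. tc1/tc2: color of the top node of the chain
    // last processed on u's / v's side (-1 = no chain yet).
    int query(int u, int v) const {
        int ans = 0, tc1 = -1, tc2 = -1, cl, cr;
        while (top[u] != top[v]) {
            // the markers must travel with their side
            if (dep[top[u]] < dep[top[v]]) { std::swap(u, v); std::swap(tc1, tc2); }
            ans += scan(id[top[u]], id[u], cl, cr);
            if (cr == tc1) --ans;  // u is the parent of that earlier chain top
            tc1 = cl;
            u = fa[top[u]];
        }
        if (dep[u] > dep[v]) { std::swap(u, v); std::swap(tc1, tc2); }
        ans += scan(id[u], id[v], cl, cr);
        if (cl == tc1) --ans;  // shallower end u meets u's side
        if (cr == tc2) --ans;  // deeper end v meets v's side
        return ans;
    }

    // Brute force: list the path colors from u to v and count segments.
    int naive(int u, int v) const {
        std::vector<int> path, tail;  // tail holds v's side, reversed later
        while (dep[u] > dep[v]) { path.push_back(col[u]); u = fa[u]; }
        while (dep[v] > dep[u]) { tail.push_back(col[v]); v = fa[v]; }
        while (u != v) {
            path.push_back(col[u]); u = fa[u];
            tail.push_back(col[v]); v = fa[v];
        }
        path.push_back(col[u]);  // the lowest common ancestor
        path.insert(path.end(), tail.rbegin(), tail.rend());
        int runs = 1;
        for (std::size_t i = 1; i < path.size(); ++i)
            if (path[i] != path[i - 1]) ++runs;
        return runs;
    }
};
```

To use it, build `PathRuns t(n, colors, edges)` with `colors` of size n+1. Then compare `t.query(a, b)` against `t.naive(a, b)` over all pairs on small random trees. This kind of check catches a swapped tc1/tc2 immediately.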
Code:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#define lo o<<1
#define ro o<<1|1
#define mid ((l+r)>>1)
using namespace std;
const int N=100010;
int n, m, col[N], rt=1;
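// hson: heavy child, id: position in HLD order, real: node at a position,
// top: top of the heavy chain, siz: subtree size, dep: depth, fa: parent
int hson[N], id[N], real[N], top[N], siz[N], dep[N], fa[N], tot;
// nxt rather than next: a global named next is ambiguous with std::next under "using namespace std" (C++11 and later)
int head[N], nxt[N<<1], to[N<<1], now;
void add(int u, int v) {to[++now]=v; nxt[now]=head[u];head[u]=now;}
// sum: number of color segments in the interval, lc/rc: colors at its left/right end,
// laz: pending "paint the whole interval" tag (0 = no tag, so colors are assumed nonzero)
// 8N rather than 4N: pushdown also runs on leaves and writes to their (unused) children
int sum[N<<3], lc[N<<3], rc[N<<3], laz[N<<3];
// if the left half ends with the color the right half starts with, one segment was counted twice
void pushup(int o) {sum[o]=sum[lo]+sum[ro];if(rc[lo] == lc[ro]) sum[o]--; lc[o]=lc[lo], rc[o]=rc[ro];}
void pushdown(int o)
{
if(!laz[o]) return ; sum[lo]=sum[ro]=1; // each child becomes a single segment
lc[lo]=rc[lo]=lc[ro]=rc[ro]=laz[lo]=laz[ro]=laz[o], laz[o]=0;
}
void build(int o, int l, int r)
{
if(l == r) {lc[o]=rc[o]=col[real[l]]; sum[o]=1; return ;}
build(lo, l, mid); build(ro, mid+1, r); pushup(o);
}
// paint positions [L,R] with color c
void dye(int o, int l, int r, int L, int R, int c)
{
pushdown(o); if(l > R || r < L) return ;
if(l >= L && r <= R) {laz[o]=lc[o]=rc[o]=c, sum[o]=1; return ;}
dye(lo, l, mid, L, R, c); dye(ro, mid+1, r, L, R, c); pushup(o);
}
// number of segments on [L,R]; c1/c2 receive the colors at positions L and R
int query(int o, int l, int r, int L, int R, int &c1, int &c2)
{
pushdown(o); if(l > R || r < L) return 0; if(l == L) c1=lc[o]; if(r == R) c2=rc[o];
if(l >= L && r <= R) return sum[o];
if(R <= mid) return query(lo, l, mid, L, R, c1, c2);
if(L > mid) return query(ro, mid+1, r, L, R, c1, c2);
int ans=query(lo, l, mid, L, R, c1, c2)+query(ro, mid+1, r, L, R, c1, c2);
if(rc[lo] == lc[ro]) ans--; // [L,R] straddles mid: same merge rule as pushup
return ans;
}
// first pass: parent, depth, subtree size, heavy child
void dfs1(int u, int f)
{
fa[u]=f; siz[u]=1; int v;
for(int i=head[u]; i; i=nxt[i]) if((v=to[i]) != f)
{
dep[v]=dep[u]+1; dfs1(v, u); siz[u]+=siz[v];
if(!hson[u] || siz[v] > siz[hson[u]]) hson[u]=v;
}
}
// second pass: chain tops and positions, heavy child first so every chain is contiguous
void dfs2(int u, int anc)
{
top[u]=anc; id[u]=++tot; real[tot]=u; int v;
if(!hson[u]) return ; dfs2(hson[u], anc);
for(int i=head[u]; i; i=nxt[i]) if((v=to[i]) != fa[u] && v != hson[u]) dfs2(v, v);
}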
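// paint every node on the path u-v with color c
void Dye(int u, int v, int c)
{
int tu=top[u], tv=top[v];
while(tu != tv)
{
if(dep[tu] < dep[tv]) swap(tu, tv), swap(u, v);
dye(1, 1, n, id[tu], id[u], c); u=fa[tu], tu=top[u];
}
if(dep[u] > dep[v]) swap(u, v); dye(1, 1, n, id[u], id[v], c);
}
// number of color segments on the path u-v
// tc1/tc2: color of the top node of the chain last processed on u's/v's side (0 = none yet)
int Query(int u, int v)
{
int tu=top[u], tv=top[v], ans=0, tc1=0, tc2=0, Lc, Rc;
while(tu != tv)
{
if(dep[tu] < dep[tv]) swap(tu, tv), swap(u, v), swap(tc1, tc2); // tc1/tc2 must follow u/v
ans+=query(1, 1, n, id[tu], id[u], Lc, Rc); if(Rc == tc1) ans--; // u is the parent of the previous chain top on this side
tc1=Lc; u=fa[tu]; tu=top[u];
}
// same chain now: after the swap u is the shallower end, touching u's side (tc1); v touches v's side (tc2)
if(dep[u] > dep[v]) swap(u, v), swap(tc1, tc2); ans+=query(1, 1, n, id[u], id[v], Lc, Rc);
if(Lc == tc1) ans--; if(Rc == tc2) ans--; return ans; // this is the step that had me stuck
}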
int read(){
int out=0, f=1; char c=getchar(); while(c < '0' || c > '9') {if(c == '-') f=-1; c=getchar();}
while(c >= '0' && c <= '9') {out=(out<<1)+(out<<3)+c-'0'; c=getchar();}
return out*f;
}
void solve()
{
int u, v, c;
for(int i=1; i <= m; i++)
{
char cmd=getchar(); while(cmd != 'Q' && cmd != 'C') cmd=getchar(); u=read(), v=read();
if(cmd == 'C') c=read(), Dye(u, v, c);
if(cmd == 'Q') printf("%d\n", Query(u, v));
}
}
void init()
{
n=read(), m=read(); int u, v;
for(int i=1; i <= n; i++) col[i]=read();
for(int i=1; i < n; i++) u=read(), v=read(), add(u, v), add(v, u);
dfs1(rt, 0); dfs2(rt, rt); build(1, 1, n);
}
int main()
{
init(); solve();
return 0;
}
Keeping this one as a memento: it is the problem I have gotten wrong more times than any other so far.