题目
http://www.lydsy.com/JudgeOnline/problem.php?id=2243
题解
调了一上午，原因是线段树还没写熟。
一开始竟然在 pushdown 函数里又调用了 pushdown，导致超时；好在出题人手下留情，只 TLE 了 10 个测试点。
还要多练啊。
其实没什么难点，这道题的关键在于合并两段相邻区间：颜色段数先相加，再看交界处的颜色（左段最右端的颜色与右段最左端的颜色）是否相同，相同就减 1。因此每个区间需要维护三个量：颜色段数、最左端颜色、最右端颜色。
————————————————————————————————————————————————
2017.3.8
今天刷 LCT 的时候又看到了这道题，本以为会很简单……结果调了两节课。其实一开始就几乎写对了，错在 pushup 之前没有先对子节点 pushdown……归根结底还是 splay 写得不熟练，连带着 LCT 也受到了牵连。
用 LCT 的话，每个节点维护：自身的颜色、子树中颜色段的个数、子树最左端的颜色、子树最右端的颜色，就够了。注意下传翻转标记时要同时交换最左端和最右端的颜色。
代码
// Solution 1: heavy-light decomposition (树链剖分) + segment tree
#include <cstdio>
#include <algorithm>
#include <iostream>
#define maxn 200010
using namespace std;
// Segment-tree node over DFS timestamps.  The same type also serves as a bare
// "segment summary" (cnt, lc, rc) when a path is accumulated in count():
//   root   - the segment tree built by init()
//   p1, p2 - per-endpoint path accumulators used by count()
//   t      - scratch summary that segcount() fills in
struct segtree
{
    int l, r;           // position range [l, r] covered by this node
    int lc, rc;         // colour at position l / at position r
    int cnt;            // number of maximal same-colour runs in [l, r]
    int set;            // pending assignment not yet applied here, -1 = none
    segtree *lch, *rch; // children (both null on a leaf)
    segtree() : l(0), r(0), lc(-1), rc(-1), cnt(0), set(-1), lch(0), rch(0) {}
}*root, *p1, *p2, *t;
int N, M, tid[maxn], fa[maxn], size[maxn], top[maxn], son[maxn], tim, head[maxn],
next[maxn], to[maxn], tmp[maxn], w[maxn], deep[maxn], tot;
void adde(int a, int b){to[++tot]=b;next[tot]=head[a];head[a]=tot;}
void pushdown(segtree *p)
{
if(p->set!=-1)
{
p->cnt=1;
p->lc=p->rc=p->set;
if(p->lch)p->lch->set=p->set,p->rch->set=p->set;
p->set=-1;
}
}
void update(segtree *p)
{
if(p->lch==0)return;
pushdown(p->lch),pushdown(p->rch);
p->cnt=p->lch->cnt+p->rch->cnt;
if(p->lch->rc==p->rch->lc)p->cnt--;
p->lc=p->lch->lc,p->rc=p->rch->rc;
}
void build(segtree *p, int l, int r)
{
int mid=(l+r)>>1;
p->l=l,p->r=r;
if(l==r){p->lc=p->rc=w[l],p->cnt=1;return;}
build(p->lch=new segtree,l,mid);
build(p->rch=new segtree,mid+1,r);
update(p);
}
void segset(segtree *p, int l, int r, int c)
{
int mid=(p->l+p->r)>>1;
pushdown(p);
if(l<=p->l and r>=p->r){p->set=c;return;}
if(l<=mid)segset(p->lch,l,r,c);
if(r>mid)segset(p->rch,l,r,c);
update(p);
}
int segcount(segtree *p, int l, int r, segtree *seg)
{
pushdown(p);
int mid=(p->l+p->r)>>1, ans=0;
if(l<=p->l and r>=p->r)
{
if(p->l==l)seg->lc=p->lc;
if(p->r==r)seg->rc=p->rc;
return p->cnt;
}
if(l<=mid)ans+=segcount(p->lch,l,r,seg);
if(r>mid)ans+=segcount(p->rch,l,r,seg);
if(l<=mid and r>mid and p->lch->rc==p->rch->lc)ans--;
return ans;
}
void dfs1(int pos)
{
int p;
size[pos]=1;
for(p=head[pos];p;p=next[p])
{
if(to[p]==fa[pos])continue;
fa[to[p]]=pos;
deep[to[p]]=deep[pos]+1;
dfs1(to[p]);
if(size[to[p]]>size[son[pos]])son[pos]=to[p];
size[pos]+=size[to[p]];
}
}
void dfs2(int pos, int tp)
{
int p;
top[pos]=tp;
tid[pos]=++tim;
if(son[pos])dfs2(son[pos],tp);
for(p=head[pos];p;p=next[p])
{
if(to[p]==fa[pos] or to[p]==son[pos])continue;
dfs2(to[p],to[p]);
}
}
int count(int a, int b)
{
int ta=top[a], tb=top[b];
while(ta!=tb)
{
if(deep[ta]<deep[tb])swap(a,b),swap(ta,tb),swap(p1,p2);
t->cnt=segcount(root,tid[ta],tid[a],t);
p1->cnt+=t->cnt;
if(p1->lc==t->rc)p1->cnt--;
p1->lc=t->lc;
a=fa[ta];ta=top[a];
}
if(deep[a]>deep[b])swap(a,b),swap(ta,tb),swap(p1,p2);
t->cnt=segcount(root,tid[a],tid[b],t);
p2->cnt+=t->cnt;
if(p2->lc==t->rc)p2->cnt--;
p2->lc=t->lc;
return p1->cnt+p2->cnt-(p1->lc==p2->lc);
}
void set(int a, int b, int c)
{
int ta=top[a], tb=top[b];
while(ta!=tb)
{
if(deep[ta]<deep[tb])swap(a,b),swap(ta,tb);
segset(root,tid[ta],tid[a],c);
a=fa[ta];ta=top[a];
}
if(deep[a]>deep[b])swap(a,b);
segset(root,tid[a],tid[b],c);
}
// Read the tree and initial colours, run both HLD passes and build the
// segment tree over DFS timestamps.
void init()
{
    int a, b;
    scanf("%d%d",&N,&M);
    for (int v = 1; v <= N; ++v)
        scanf("%d", &tmp[v]);
    for (int e = 1; e < N; ++e)
    {
        scanf("%d%d", &a, &b);
        adde(a, b);
        adde(b, a);
    }
    dfs1(1);
    dfs2(1, 1);
    // Re-index colours by timestamp so each heavy chain is a contiguous range.
    for (int v = 1; v <= N; ++v)
        w[tid[v]] = tmp[v];
    build(root = new segtree, 1, tim);
}
// Answer M operations: "C a b c" paints path a..b, "Q a b" prints the number
// of colour runs on path a..b.
int main()
{
    char type[10];
    int a, b, c;
    init();
    p1 = new segtree;
    p2 = new segtree;
    t = new segtree;
    for (int q = 0; q < M; ++q)
    {
        // count() expects both path accumulators to start empty.
        p1->cnt = p2->cnt = 0;
        p1->lc = p2->lc = -1;
        scanf("%s", type);
        switch (*type)
        {
        case 'C':
            scanf("%d%d%d", &a, &b, &c);
            set(a, b, c);
            break;
        case 'Q':
            scanf("%d%d", &a, &b);
            printf("%d\n", count(a, b));
            break;
        }
    }
    return 0;
}
// Solution 2: link-cut tree (LCT)
#include <cstdio>
#include <algorithm>
#define maxn 500000
using namespace std;
int N; // number of vertices
// Link-cut-tree vertex.  The summary fields describe the in-order sequence of
// this vertex's splay subtree.
// nd: vertex storage indexed by vertex id; s: scratch stack used by splay().
struct node
{
    int c;       // this vertex's own colour
    int lc, rc;  // colour of the first / last vertex of the subtree's sequence
    int rev;     // pending reversal of this subtree, not yet applied here
    int set;     // pending paint tag, 0 = none; interpreted by pushdown()
    int cnt;     // number of colour runs in the subtree's sequence
    node *f;     // splay parent, or path-parent pointer on a splay root
    node *ch[2]; // ch[0] = left, ch[1] = right splay child
}nd[maxn], *s[maxn];
// Which child of its splay parent x is: 0 = left, 1 = right; -1 when x has
// no parent or x->f is only a path-parent pointer.
inline int getwh(node *x)
{
    node *par = x->f;
    if (!par)
        return -1;
    if (par->ch[0] == x)
        return 0;
    return par->ch[1] == x ? 1 : -1;
}
// True when x is the root of its splay tree: x is not a child of x->f.
inline bool isroot(node *x)
{
    return getwh(x) < 0;
}
// Make x the wh-th child of y.  Either pointer may be null; only the
// non-null side is written.
inline void join(node *x, node *y, int wh)
{
    if (y != 0)
        y->ch[wh] = x;
    if (x != 0)
        x->f = y;
}
// Toggle x's pending-reversal flag (no-op on a null child).
inline void rev(node *x)
{
    if (x != 0)
        x->rev ^= 1;
}
// Queue "paint x's whole splay subtree with colour c" on x (no-op if x is
// null).  The tag is stored as c+1 so that 0 keeps meaning "nothing pending"
// even when c itself is 0; storing c directly would silently drop a paint
// with colour 0.  Assumes 0 <= c < INT_MAX.
inline void set(node *x, int c){if(x)x->set=c+1;}
// Apply x's pending tags to x itself and forward them to its children, where
// they stay pending until those children are pushed down in turn.
//  - rev: swap the children and the end colours lc/rc, toggle children's flags.
//  - set: decode the colour (tag - 1); the subtree becomes a single run.
inline void pushdown(node *x)
{
    if(x->rev)
    {
        swap(x->ch[0],x->ch[1]);
        swap(x->lc,x->rc);
        rev(x->ch[0]);rev(x->ch[1]);x->rev=0;
    }
    if(x->set)
    {
        int col=x->set-1;
        x->c=x->lc=x->rc=col;
        x->cnt=1;
        set(x->ch[0],col),set(x->ch[1],col);
        x->set=0;
    }
}
// Recompute x's summary (lc, rc, cnt) from its own colour and its children.
// The children are pushed down first so their fields reflect any pending
// reversal or paint.
inline void pushup(node *x)
{
    node *lch = x->ch[0];
    node *rch = x->ch[1];
    if (lch)
        pushdown(lch);
    if (rch)
        pushdown(rch);

    x->cnt = 1;
    x->lc = x->c;
    x->rc = x->c;
    if (lch)
    {
        // lch's last vertex is adjacent to x in the in-order sequence.
        x->lc = lch->lc;
        x->cnt += lch->cnt - (lch->rc == x->c);
    }
    if (rch)
    {
        // rch's first vertex is adjacent to x in the in-order sequence.
        x->rc = rch->rc;
        x->cnt += rch->cnt - (rch->lc == x->c);
    }
}
// Rotate x one level up over its splay parent y, keeping the in-order
// sequence.  Must only be called when x is not a splay root.
inline void rotate(node *x)
{
// c = x's side under y.
node *y=x->f, *z=y->f; int c=getwh(x);
// If y was a splay root, x just inherits y's f (path-parent) pointer;
// otherwise x takes y's slot under z.
if(isroot(y))x->f=y->f;
else join(x,z,getwh(y));
// x's inner child moves under y, then y hangs on x's opposite side.
join(x->ch[!c],y,c);
join(y,x,!c);
// y is now below x, so it is recomputed first.
pushup(y),pushup(x);
}
// Splay x to the root of its splay tree.  Pending tags on the path from the
// splay root down to x are pushed first so rotations see current children.
// NOTE(review): the scratch stack s must hold splay depth + 1 entries.
inline void splay(node *x)
{
node *y; int top=0;
// Record x and all its splay ancestors; the splay root is pushed last.
for(y=x;!isroot(y);y=y->f)s[++top]=y;s[++top]=y;
// Push tags down from the splay root towards x.
for(;top;top--)pushdown(s[top]);
while(!isroot(x))
{
y=x->f;
// Zig: the parent is the root, one rotation finishes.
if(isroot(y)){rotate(x);return;}
// Zig-zag (x, y on opposite sides): rotate x twice.
// Zig-zig (same side): rotate y first, then x.
if(getwh(x)^getwh(y))rotate(x);else rotate(y);
rotate(x);
}
}
// Make the path from x up to its tree root preferred: splay each vertex on
// the way up and attach the previously processed piece as its right child.
inline void access(node *x)
{
    for (node *below = 0; x; below = x, x = x->f)
    {
        splay(x);
        x->ch[1] = below;
        pushup(x);
    }
}
// Re-root x's tree at x: expose the path to x, splay x to the top of it and
// queue a reversal of that path.
void makeroot(node *x)
{
    access(x);
    splay(x);
    rev(x);
}
// Add the tree edge x-y: make x the root of its tree and hang it under y via
// x's f pointer.  No check is made that x and y are in different trees.
void link(node *x, node *y)
{
    makeroot(x);
    x->f = y;
}
// Read the tree, then answer M operations: "C a b c" paints path a..b,
// "Q a b" prints the number of colour runs on path a..b.
int main()
{
    int a, b, M, c;
    char type[5];
    scanf("%d%d",&N,&M);
    // Every vertex starts as a single-vertex path: one run of its own colour.
    for (int v = 1; v <= N; ++v)
    {
        scanf("%d", &c);
        nd[v].c = nd[v].lc = nd[v].rc = c;
        nd[v].cnt = 1;
    }
    for (int e = 1; e < N; ++e)
    {
        scanf("%d%d", &a, &b);
        link(nd + a, nd + b);
    }
    for (int q = 1; q <= M; ++q)
    {
        scanf("%s%d%d", type, &a, &b);
        bool paint = (*type == 'C');
        if (paint)
            scanf("%d", &c);
        // Expose path a..b as the whole splay tree rooted at b.
        makeroot(nd + a);
        access(nd + b);
        splay(nd + b);
        if (paint)
            set(nd + b, c);
        else
            printf("%d\n", nd[b].cnt);
    }
    return 0;
}