染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 3626 Solved: 1380
[ Submit][ Status][ Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
思路:树链刨分,就是线段树敲起来比较麻烦,每次查询一个区间有多少段,并且记录这个区间的左右端点值,然后把这个路径上的这些区间的查询结果放在一个vector里面,再依次判断他们的相邻部分是否相同。
AC代码如下:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
struct node
{
int u,v;
}arr[100010];
struct node2
{
int u,v,next;
}edge[200010];
struct node3
{
int l,r,part,lazy,L,R;
}tree[400010];
struct node4
{
int part,L,R;
};
vector<node4> vc1,vc2;
node4 A,B;
int T,t,n,m,tot,tot2,Head[100010];
int siz[100010],son[100010],fa[100010],depth[100010],top[100010],p[100010];
int c[100010],c2[100010];
char s[10];
void add(int u,int v)
{
edge[tot].u=u;
edge[tot].v=v;
edge[tot].next=Head[u];
Head[u]=tot++;
}
void dfs1(int u)
{
int i,j,k,v;
siz[u]=1;
son[u]=0;
for(i=Head[u];i!=-1;i=edge[i].next)
{
v=edge[i].v;
if(v==fa[u])
continue;
fa[v]=u;
depth[v]=depth[u]+1;
dfs1(v);
if(siz[v]>siz[son[u]])
son[u]=v;
siz[u]+=siz[v];
}
}
void dfs2(int u,int f)
{
int i,j,v;
p[u]=++tot2;
top[u]=f;
if(son[u]!=0)
dfs2(son[u],f);
for(i=Head[u];i!=-1;i=edge[i].next)
{
v=edge[i].v;
if(v==fa[u] || v==son[u])
continue;
dfs2(v,v);
}
}
void op(int tr,int num)
{
tree[tr].lazy=num;
tree[tr].part=1;
tree[tr].L=tree[tr].R=num;
}
void up(int tr)
{
tree[tr].part=tree[tr*2].part+tree[tr*2+1].part;
if(tree[tr*2].R==tree[tr*2+1].L)
tree[tr].part--;
tree[tr].L=tree[tr*2].L;
tree[tr].R=tree[tr*2+1].R;
}
void down(int tr)
{
if(tree[tr].lazy>=0)
{
op(tr*2,tree[tr].lazy);
op(tr*2+1,tree[tr].lazy);
tree[tr].lazy=-1;
}
}
void build(int l,int r,int tr)
{
tree[tr].l=l;
tree[tr].r=r;
tree[tr].lazy=-1;
if(l==r)
{
op(tr,c2[l]);
return;
}
int mi=(l+r)/2;
build(l,mi,tr*2);
build(mi+1,r,tr*2+1);
up(tr);
}
void update(int l,int r,int tr,int num)
{
if(tree[tr].l==l && tree[tr].r==r)
{
op(tr,num);
return;
}
down(tr);
int mi=(tree[tr].l+tree[tr].r)/2;
if(r<=mi)
update(l,r,tr*2,num);
else if(l>mi)
update(l,r,tr*2+1,num);
else
{
update(l,mi,tr*2,num);
update(mi+1,r,tr*2+1,num);
}
up(tr);
}
void query(int l,int r,int tr,int &part,int &L,int &R)
{
if(tree[tr].l==l && tree[tr].r==r)
{
part=tree[tr].part;
L=tree[tr].L;
R=tree[tr].R;
return;
}
down(tr);
int mi=(tree[tr].l+tree[tr].r)/2;
if(r<=mi)
query(l,r,tr*2,part,L,R);
else if(l>mi)
query(l,r,tr*2+1,part,L,R);
else
{
int part1,part2,L1,L2,R1,R2;
query(l,mi,tr*2,part1,L1,R1);
query(mi+1,r,tr*2+1,part2,L2,R2);
part=part1+part2;
if(R1==L2)
part--;
L=L1;
R=R2;
}
}
void solve_c(int u,int v,int num)
{
int f1=top[u],f2=top[v];
while(f1!=f2)
{
if(depth[f1]<depth[f2])
{
swap(f1,f2);
swap(u,v);
}
update(p[f1],p[u],1,num);
u=fa[f1];
f1=top[u];
}
if(depth[u]>depth[v])
swap(u,v);
update(p[u],p[v],1,num);
}
int solve_q(int u,int v)
{
vc1.clear();
vc2.clear();
int f1=top[u],f2=top[v],i,j,k,part,L,R;
while(f1!=f2)
{
if(depth[f1]>depth[f2])
{
query(p[f1],p[u],1,A.part,A.R,A.L);
vc1.push_back(A);
u=fa[f1];
f1=top[u];
}
else
{
query(p[f2],p[v],1,A.part,A.L,A.R);
vc2.push_back(A);
v=fa[f2];
f2=top[v];
}
}
if(depth[v]>=depth[u])
{
query(p[u],p[v],1,A.part,A.L,A.R);
vc1.push_back(A);
}
else
{
query(p[v],p[u],1,A.part,A.R,A.L);
vc1.push_back(A);
}
for(i=vc2.size()-1;i>=0;i--)
vc1.push_back(vc2[i]);
part=vc1[0].part;
R=vc1[0].R;
for(i=1;i<vc1.size();i++)
{
part+=vc1[i].part;
if(R==vc1[i].L)
part--;
R=vc1[i].R;
}
return part;
}
int main()
{
int i,j,k,u,v,num;
while(~scanf("%d%d",&n,&m))
{
memset(Head,-1,sizeof(Head));
tot=tot2=0;
for(i=1;i<=n;i++)
scanf("%d",&c[i]);
for(i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
depth[1]=1;
dfs1(1);
dfs2(1,1);
for(i=1;i<=n;i++)
c2[p[i]]=c[i];
build(1,tot2,1);
while(m--)
{
scanf("%s",s+1);
if(s[1]=='C')
{
scanf("%d%d%d",&u,&v,&num);
solve_c(u,v,num);
}
else
{
scanf("%d%d",&u,&v);
num=solve_q(u,v);
printf("%d\n",num);
}
}
}
}