题目大意:给定一棵树,每个点有颜色和点权;支持修改点的颜色、修改点权,以及询问一条路径上与起点同色的点的点权和或最大值。
数据规模10^5
主要思想是树链剖分,并对每种颜色开一棵动态开点的线段树;改色时从原颜色的线段树中删除该点,再把它插入到新颜色的线段树中即可。
需要注意:不仅插入时要向上更新节点信息,删除时也要向上更新(并回收变空的节点)。
自己的代码:
#include<cstdio>
#include<cstring>
// Member `a` of node pointer x, or 0 when x is null (an absent child).
// The pointer argument is parenthesised so expressions such as `ptr+1`
// expand correctly (the unparenthesised form produced `ptr+1->a`).
#define safe(x,a) ((x)?(x)->a:0)
// Bound for per-vertex and per-colour arrays (assumed: vertices/colours <= 1e5).
#define gm 100001
using namespace std;
// Node-pool block size; also the capacity of the recycled-node stack.
const int size=1<<16;
// Implicit arguments/result of the segment-tree routines:
// p = target position, v = value to store or accumulated query result,
// [ll,rr] = query range.
int p,v,ll,rr;
// n = number of vertices, q = number of queries.
int n,q;
inline int __max(const int &a,const int &b)
{
return a<b?b:a;
}
// Node of a dynamically created segment tree over positions [1,n].
// Only nodes on paths to positions that currently hold a value exist; a null
// child contributes 0 through the `safe` macro. The routines take their
// arguments through the globals p, v, ll, rr rather than parameters.
struct node
{
    node *l,*r;     // children covering [x,mid] and [mid+1,y]; null if absent
    int max,sum;    // maximum / sum of the values stored below this node
    inline void* operator new(size_t);
    inline void operator delete(void*);

    // Store v at position p inside this node's range [x,y], creating missing
    // nodes on the way down and refreshing max/sum on the way back up.
    void change(int x,int y)
    {
        if(x==y)
        {
            max=sum=v;
            return;
        }
        int mid=(x+y)>>1;
        if(p<=mid)
        {
            // Value-initialise: guarantees null children and 0 max/sum instead
            // of relying on the allocator having zeroed the bytes.
            if(!l) l=new node();
            l->change(x,mid);
        }
        else
        {
            if(!r) r=new node();
            r->change(mid+1,y);
        }
        max=__max(safe(l,max),safe(r,max));
        sum=safe(l,sum)+safe(r,sum);
    }

    // Remove position p from the subtree `kore` covering [x,y]. Nodes left
    // without children are deleted and their pointer is reset to null, so an
    // emptied tree ends with kore==0. Removing a position that is not stored
    // is a no-op (previously it dereferenced a null child).
    static void free(node*& kore,int x,int y)
    {
        if(!kore) return;
        if(x==y)
        {
            delete kore;
            kore=0;
            return;
        }
        int mid=(x+y)>>1;
        if(p<=mid) free(kore->l,x,mid);
        else free(kore->r,mid+1,y);
        if(!kore->l&&!kore->r)
        {
            delete kore;
            kore=0;
            return;
        }
        kore->max=__max(safe(kore->l,max),safe(kore->r,max));
        kore->sum=safe(kore->l,sum)+safe(kore->r,sum);
    }

    // Add to v the sum of stored values at positions in [ll,rr] within [x,y].
    void getsum(int x,int y)
    {
        if(ll<=x&&y<=rr)
        {
            v+=sum;
            return;
        }
        int mid=(x+y)>>1;
        if(ll<=mid&&l) l->getsum(x,mid);
        if(mid<rr&&r) r->getsum(mid+1,y);
    }

    // Raise v to the largest stored value at positions in [ll,rr] within [x,y].
    void getmax(int x,int y)
    {
        if(ll<=x&&y<=rr)
        {
            v=__max(v,max);
            return;
        }
        int mid=(x+y)>>1;
        if(ll<=mid&&l) l->getmax(x,mid);
        if(mid<rr&&r) r->getmax(mid+1,y);
    }
}*S,*T,*F[size];   // S/T: next unused slot / end of current pool block; F: stack of recycled nodes
struct ar
{
node* rt;
ar():rt(0){}
void change(int pos,int val)
{
if(!rt) rt=new node;
p=pos;v=val;
rt->change(1,n);
}
void free(int pos)
{
p=pos;
node::free(rt,1,n);
}
int getsum(int l,int r)
{
ll=l;rr=r;v=0;
rt->getsum(1,n);
return v;
}
int getmax(int l,int r)
{
ll=l;rr=r;v=0;
rt->getmax(1,n);
return v;
}
}a[gm];
int tp=-1;
inline void* node::operator new(size_t)
{
if(~tp) return F[tp--];
if(S==T)
{
S=new node[size];
T=S+size;
memset(S,0,sizeof(node[size]));
}
return S++;
}
inline void node::operator delete(void* p)
{
if(tp==size-1) ::delete (node*)p;
else
{
memset(p,0,sizeof(node));
F[++tp]=(node*)p;
}
}
// w[i]: weight of vertex i; c[i]: its colour.
int w[gm],c[gm];
// Adjacency-list edge: target vertex plus the next edge leaving the same vertex.
struct e
{
    int t;      // endpoint of this edge
    e *n;       // next edge in the list (null at the end)
    e(int to,e *nx):t(to),n(nx){}
}*f[gm];        // f[u]: head of vertex u's edge list
// Prepend the directed edge a->b to a's edge list.
#define link(a,b) f[a]=new e(b,f[a])
int sz[gm],fat[gm],son[gm],dpt[gm];
void dfs1(int x)
{
sz[x]=1;
int maxs=0,maxw=0,y;
for(e *i=f[x];i;i=i->n)
{
y=i->t;
if(fat[x]==y) continue;
fat[y]=x;
dpt[y]=dpt[x]+1;
dfs1(y);
sz[x]+=sz[y];
if(sz[y]>maxw)
maxs=y,maxw=sz[y];
}
son[x]=maxs;
}
int pos[gm],top[gm],ct=0;
void dfs2(int x)
{
pos[x]=++ct;
a[c[x]].change(ct,w[x]);
top[x]=x==son[fat[x]]?top[fat[x]]:x;
if(son[x]) dfs2(son[x]);
for(e *i=f[x];i;i=i->n)
{
if(fat[x]==i->t||son[x]==i->t) continue;
dfs2(i->t);
}
}
char o[3];   // command token: two letters plus the terminator
int z;       // scratch storage used by swap()
// Exchange two int lvalues through z. Fully parenthesised, so the expansion is
// a single expression. The old bare comma form turned `int k=swap(a,b);` into
// declarations of fresh `a` and `b`, leaving the originals untouched.
#define swap(x,y) (z=(x),(x)=(y),(y)=z)
int getsum(int x,int y)
{
int res=0;
ar& kre=a[c[x]];
while(top[x]!=top[y])
{
if(dpt[top[x]]<dpt[top[y]])
swap(x,y);
res+=kre.getsum(pos[top[x]],pos[x]);
x=fat[top[x]];
}
if(dpt[x]>dpt[y])
swap(x,y);
return res+kre.getsum(pos[x],pos[y]);
}
int getmax(int x,int y)
{
int res=0;
ar& kre=a[c[x]];
while(top[x]!=top[y])
{
if(dpt[top[x]]<dpt[top[y]])
swap(x,y);
res=__max(res,kre.getmax(pos[top[x]],pos[x]));
x=fat[top[x]];
}
if(dpt[x]>dpt[y])
swap(x,y);
return __max(res,kre.getmax(pos[x],pos[y]));
}
// Reads the tree (n vertices with weight and colour, n-1 edges), builds the
// heavy-light decomposition, then answers q commands.
// Malformed or truncated input stops processing instead of reading garbage.
int main()
{
    // n<=0 would divide by zero when picking the root below.
    if(scanf("%d%d",&n,&q)!=2||n<=0) return 0;
    for(int i=1;i<=n;i++)
        if(scanf("%d%d",w+i,c+i)!=2) return 0;
    int s1,s2;
    for(int i=1;i<n;i++)
    {
        if(scanf("%d%d",&s1,&s2)!=2) return 0;
        link(s1,s2);
        link(s2,s1);
    }
    // Any vertex works as the root; this picks a fixed one in [1,n].
    int rt=114514%n+1;
    dfs1(rt);
    dfs2(rt);
    for(int i=1;i<=q;i++)
    {
        // Width-limited read: o holds two command letters plus the terminator.
        if(scanf("%2s%d%d",o,&s1,&s2)!=3) break;
        switch(o[1])
        {
        case 'C':
            // Recolour s1 to s2: move its entry from the old colour's tree
            // to the new colour's tree at the same position.
            a[c[s1]].free(pos[s1]);
            c[s1]=s2;
            a[s2].change(pos[s1],w[s1]);
            break;
        case 'W':
            // Set the weight of s1 to s2 in its colour's tree.
            a[c[s1]].change(pos[s1],s2);
            w[s1]=s2;
            break;
        case 'S':
            printf("%d\n",getsum(s1,s2));
            break;
        case 'M':
            printf("%d\n",getmax(s1,s2));
            break;
        }
    }
    return 0;
}