这题就是在树上维护四个操作:查询树链上所有节点值的和、树链上所有节点值变成同一个数、查询子树上所有节点值的和、子树上所有节点值变成同一个数.
树链剖分+线段树可以支持上面所有操作,所以只要树链剖分+线段树就可以了.
AC code:
#include <cstdio>
#include <vector>
using namespace std;
const int N=100010;
int n,q,cnt,tot;
int siz[N],pre[N],num[N],top[N];
vector<int> G[N];
struct nod{
int l,r,sum,tag;
nod *lc,*rc;
}pool[N<<2];
struct Segtree{
nod *root;
Segtree(){
build(&root,1,n);
}
void pushdown(nod *p){
if(p->tag==-1) return ;
p->sum=p->tag*(p->r-p->l+1);
if(p->lc!=NULL) p->lc->tag=p->tag,p->rc->tag=p->tag;
p->tag=-1;
}
void build(nod **p,int L,int R){
*p=&pool[tot++];
(*p)->l=L;(*p)->r=R;(*p)->tag=-1;
if(L==R){
(*p)->sum=1;
return ;
}
int M=(L+R)>>1;
build(&(*p)->lc,L,M);
build(&(*p)->rc,M+1,R);
(*p)->sum=(*p)->lc->sum+(*p)->rc->sum;
}
void change(nod *p,int L,int R,int v){
pushdown(p);
if(p->l==L&&p->r==R){
p->tag=v;
return ;
}
int M=(p->l+p->r)>>1;
if(R<=M) change(p->lc,L,R,v);
else if(L>M) change(p->rc,L,R,v);
else{
change(p->lc,L,M,v);
change(p->rc,M+1,R,v);
}
pushdown(p->lc);pushdown(p->rc);
p->sum=p->lc->sum+p->rc->sum;
}
int getsum(nod *p,int L,int R){
pushdown(p);
if(p->l==L&&p->r==R) return p->sum;
int M=(p->l+p->r)>>1;
if(R<=M) return getsum(p->lc,L,R);
else if(L>M) return getsum(p->rc,L,R);
else return getsum(p->lc,L,M)+getsum(p->rc,M+1,R);
}
};
void dfs1(int x,int pr){
pre[x]=pr;siz[x]=1;
for(int i=0;i<(int)G[x].size();i++){
int y=G[x][i];
if(y==pr) continue;
dfs1(y,x);
siz[x]+=siz[y];
}
}
void dfs2(int x,int tp){
top[x]=tp;num[x]=++cnt;
int mx=0,y=-1;
for(int i=0;i<(int)G[x].size();i++){
int z=G[x][i];
if(z==pre[x]) continue;
if(siz[z]>mx){
y=z;
mx=siz[z];
}
}
if(y!=-1) dfs2(y,tp);
for(int i=0;i<(int)G[x].size();i++){
int z=G[x][i];
if(z==y||z==pre[x]) continue;
dfs2(z,z);
}
}
int main(){
scanf("%d",&n);
Segtree T;
for(int i=1;i<n;i++){
int v;
scanf("%d",&v);
G[v].push_back(i);
G[i].push_back(v);
}
dfs1(0,-1);
dfs2(0,0);
scanf("%d",&q);
for(int i=1;i<=q;i++){
int x;
char s[11];
scanf("%s%d",s,&x);
if(s[0]=='i'){
int sum=0;
for(int j=x;j!=-1;j=pre[top[j]]){
sum+=T.getsum(T.root,num[top[j]],num[j]);
T.change(T.root,num[top[j]],num[j],0);
}
printf("%d\n",sum);
}
else{
printf("%d\n",siz[x]-T.getsum(T.root,num[x],num[x]+siz[x]-1));
T.change(T.root,num[x],num[x]+siz[x]-1,1);
}
}
return 0;
}