离线一下,按x递增处理询问
我们先对这棵树树剖一下,然后对于每一点维护非重儿子的所在子树的点的个数*ksm(dep,k)然后只需要再套个线段树就行了
#include<bits/stdc++.h>
#define pb(x) push_back(x)
#define mk(x,y) make_pair(x,y)
#define ll long long
using namespace std;
const int mod=998244353;
const int N=5e4+10;
int sum[N<<2],tree[N<<2];
int size[N],dep[N],f[N],son[N],top[N],dfn[N],dnow[N],ans[N];
vector<int>e[N];
vector<pair<int,int> >q[N];
int n,k,q_sum,dfstime,x,y;
int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
int mul(int x,int y){return (ll)x*y%mod;}
int ksm(int x,int y){
int ans=1;
for (;y;y>>=1,x=mul(x,x)) if (y&1) ans=mul(ans,x);
return ans;
}
void add_edge(int x,int y){e[x].pb(y);}
int dfs1(int now,int fa,int d){
dep[now]=ksm(d,k); size[now]=1; f[now]=fa;
for (int i=0;i<e[now].size();i++){
int u=e[now][i];
size[now]+=dfs1(u,now,d+1);
if (size[son[now]]<size[u]) son[now]=u;
}
return size[now];
}
void dfs2(int now,int t){
dfn[now]=++dfstime; top[now]=t;
if (son[now]) dfs2(son[now],t);
for (int i=0;i<e[now].size();i++){
int u=e[now][i];
if (u==son[now]) continue;
dfs2(u,u);
}
}
void update(int k){
tree[k]=tree[k<<1|1]+tree[k<<1];
}
void change1(int l,int r,int k,int x){
if (l==r) tree[k]++;
else {
int mid=(l+r)/2;
if (x<=mid) change1(l,mid,k<<1,x); else change1(mid+1,r,k<<1|1,x);
update(k);
}
}
void change(int l,int r,int k,int x){
if (l==r) sum[k]=add(sum[k],dnow[l]);
else {
int mid=(l+r)/2;
if (x<=mid) change(l,mid,k<<1,x); else change(mid+1,r,k<<1|1,x);
sum[k]=(sum[k<<1]+sum[k<<1|1])%mod;
}
}
int findans(int l,int r,int k,int ql,int qr){
if (l>=ql&&r<=qr) return tree[k];
else {
int mid=(l+r)/2;
int s=0;
if (ql<=mid) s+=findans(l,mid,k<<1,ql,qr);
if (qr>mid) s+=findans(mid+1,r,k<<1|1,ql,qr);
return s;
}
}
int findans1(int l,int r,int k,int ql,int qr){
if (l>=ql&&r<=qr) return sum[k];
else {
int mid=(l+r)/2;
int s=0;
if (ql<=mid) s=add(s,findans1(l,mid,k<<1,ql,qr));
if (qr>mid) s=add(s,findans1(mid+1,r,k<<1|1,ql,qr));
return s;
}
}
void solve(int x){
int t=x;
change1(1,n,1,dfn[x]);
change(1,n,1,dfn[x]);
while (top[x]!=1){x=top[x];x=f[x];change(1,n,1,dfn[x]);}
x=t;
for (int i=0;i<q[x].size();i++) {
int now=q[x][i].first;
int s=mul(findans(1,n,1,dfn[now],dfn[now]+size[now]-1),dep[now]);
while (now>1){
s=add(s,mul(dec(findans(1,n,1,dfn[f[now]],dfn[f[now]]+size[f[now]]-1),findans(1,n,1,dfn[now],dfn[now]+size[now]-1)),dep[f[now]]));
// cout << x << ' ' << f[now] << ' ' << findans(1,n,1,dfn[f[now]],dfn[f[now]]+size[f[now]]-1) << endl;
now=f[now];
if (now!=top[now]) s=add(s,findans1(1,n,1,dfn[top[now]],dfn[f[now]]));
now=top[now];
}
ans[q[x][i].second]=s;
}
}
int main(){
scanf("%d%d%d",&n,&q_sum,&k);
for (int i=2;i<=n;i++) scanf("%d",&x),add_edge(x,i);
dfs1(1,0,1); dfs2(1,1);
for (int i=1;i<=q_sum;i++){
scanf("%d%d",&x,&y); q[x].pb(mk(y,i));
}
for (int i=1;i<=n;i++) dnow[dfn[i]]=dep[i];
for (int i=1;i<=n;i++) solve(i);
for (int i=1;i<=q_sum;i++) printf("%d\n",ans[i]);
}