题目链接:Bzoj : https://www.lydsy.com/JudgeOnline/problem.php?id=5293
落谷 : https://www.luogu.org/problemnew/show/P4427
题目描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 kk次方和,而且每次的 kk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
输入输出格式
输入格式:
第一行包含一个正整数 nn ,表示树的节点数。
之后 n-1n−1 行每行两个空格隔开的正整数 i, ji,j ,表示树上的一条连接点 ii 和点 jj 的边。
之后一行一个正整数 mm ,表示询问的数量。
之后每行三个空格隔开的正整数 i, j, ki,j,k ,表示询问从点 ii 到点 jj 的路径上所有节点深度的 kk 次方和。由于这个结果可能非常大,输出其对 998244353998244353 取模的结果。
树的节点从 11 开始标号,其中 11 号节点为树的根。
输出格式:
对于每组数据输出一行一个正整数表示取模后的结果。
输入输出样例
输入样例#1:
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
输出样例#1:
33
503245989
说明
Ps:加了输入挂应该可以快点。Bzoj有大佬2s就不过来,不知道是什么骚操作。
题意很简单,就是一棵树,问你 I 点到 J 点之间的距离和,对于经过某个点提供的贡献就是 当前节点在树中的层数的K次方。
大概如下丑图:
假如求3到6之间的距离和 k是2的话 就等于2^2+1^2+0^2+1^2+2^2+3^3 这样一个最短的路径的和。
所以预处理下30W*50的这样的一个次方的前缀和。
然后就是dfs求节点的深度,倍增求一下公共祖先的深度,然后就XJB求一下就行了,会爆炸int。
#include<bits/stdc++.h>
using namespace std;
const int maxn = 300005;
const long long mod = 998244353;
struct node{
int u;
int v;
int next;
}no[600050];
int head[maxn];
int dep[maxn];
int anc[maxn][25];
long long dis[maxn][51];
int cnt,k;
long long Ans;
void init()//预处理
{
for(long long i=1;i<=300000;i++)
{
dis[i][0]=1;
dis[i][1]=i;
for(int j=2;j<=50;j++)
dis[i][j]=(dis[i][j-1]*i)%mod;
}
for(int i=2;i<=300000;i++)
for(int j=1;j<=50;j++)
dis[i][j]=(dis[i][j]+dis[i-1][j])%mod;
}
void add(int u,int v)
{
no[cnt]={u,v,head[u]};
head[u]=cnt++;
}
void dfs(int u)//预处理倍增
{
for(int i=1;i<=20;i++)
anc[u][i]=anc[anc[u][i-1]][i-1];
for(int i=head[u];i!=-1;i=no[i].next)
{
int v=no[i].v;
if(v!=anc[u][0])
{
dep[v]=dep[u]+1;
anc[v][0]=u;
dfs(v);
}
}
}
int LCA(int u,int v)//求公共祖先节点
{
if(dep[u] < dep[v])
swap(u,v);
for(int i=20;i>=0;i--)
if(dep[anc[u][i]] >= dep[v])
u=anc[u][i];
if(u==v)
return u;
for(int i=20;i>=0;i--)
{
if(anc[u][i] != anc[v][i])
{
u=anc[u][i];
v=anc[v][i];
}
}
return anc[u][0];
}
int main()
{
init();
cnt=0;
memset(head,-1,sizeof(head));
int n,u,v;
scanf("%d",&n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dep[1]=1;
dfs(1);
int m;
scanf("%d",&m);
while(m--)
{
Ans=0;
scanf("%d%d%d",&u,&v,&k);
int temp=LCA(u,v);
int h=dep[u]-1;
if(dep[temp]-2>=0)//一边的和
Ans=Ans+(dis[h][k]-dis[dep[temp]-2][k]+mod)%mod;
else
Ans=Ans+(dis[h][k]-dis[0][k]+mod)%mod;
h=dep[v]-1;
if(dep[temp]-1>=0)//另外一边的和
Ans=Ans+(dis[h][k]-dis[dep[temp]-1][k]+mod)%mod;
else
Ans=Ans+(dis[h][k]-dis[0][k]+mod)%mod;
printf("%lld\n",Ans%mod);
}
return 0;
}