题目链接:
https://www.luogu.org/problemnew/show/P4427#sub
https://www.lydsy.com/JudgeOnline/problem.php?id=5293
http://exam.upc.edu.cn/problem.php?id=6744
题目描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 kk次方和,而且每次的 kk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
输入输出格式
输入格式:
第一行包含一个正整数 nn ,表示树的节点数。
之后 n-1n−1 行每行两个空格隔开的正整数 i, ji,j ,表示树上的一条连接点 ii 和点 jj 的边。
之后一行一个正整数 mm ,表示询问的数量。
之后每行三个空格隔开的正整数 i, j, ki,j,k ,表示询问从点 ii 到点 jj 的路径上所有节点深度的 kk 次方和。由于这个结果可能非常大,输出其对 998244353998244353 取模的结果。
树的节点从 11 开始标号,其中 11 号节点为树的根。
输出格式:
对于每组数据输出一行一个正整数表示取模后的结果。
输入输出样例
输入样例#1: 复制
5 1 2 1 3 2 4 2 5 2 1 4 5 5 4 45
输出样例#1: 复制
33 503245989
说明
样例解释
以下用 d (i)d(i) 表示第 ii 个节点的深度。
对于样例中的树,有 d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2 。
因此第一个询问答案为 (2^5 + 1^5 + 0^5)\ mod\ 998244353 = 33(25+15+05) mod 998244353=33 ,第二个询问答案为 (2^{45} + 1^{45} + 2^{45})\ mod\ 998244353 = 503245989(245+145+245) mod 998244353=503245989 。
数据范围
对于 30\%30% 的数据, 1 \leq n,m \leq 1001≤n,m≤100 。
对于 60\%60% 的数据, 1 \leq n,m \leq 10001≤n,m≤1000 。
对于 100\%100% 的数据, 1 \leq n,m \leq 300000, 1 \leq k \leq 501≤n,m≤300000,1≤k≤50 。
另外存在5个不计分的hack数据
提示
数据规模较大,请注意使用较快速的输入输出方式。
[题意]
题意很明确 求 ,
[思路] 思路很明确, 数据量很大, 而且求深度有大量重复操作, 所以, 可以 求解 u,v得最近公共祖先 anc
在处理 deep[anc] -> deep[u] 和 deep[anc]---> deep[v] 就可以了.这个过程中,我们要求解一个深度得问题.
方法一: 预处理 深度的^k 次方, 还有前缀和, 通过前缀和 求解
方法二: 在dfs 时 ,处理深度. v = u + qpow(deep[v],k) ; 求解答案时, 用容斥原理 ans[u]+ans[v] - 2*ans[LCA[u,v]] + qpow(dee[LCA[u,v]],k)
详细 看代码
[代码君]
#include <bits/stdc++.h>
typedef long long ll;
const int maxn = 3e5+100;
const int mod = 998244353;
using namespace std;
inline ll qpow(ll a,ll n)
{
ll res = 1; for(;n;n>>=1) {if(n&1) res = res*a%mod; a =a*a%mod; }return res;
}
struct node{
int v,next;
}edge[maxn*2];
int head[maxn],cnt;
int fa[maxn];
int deep[maxn];// deep
int anc[maxn][25];
ll rlen[51][maxn];
void add(int u,int v)
{
edge[cnt]={v,head[u]};
head[u] = cnt++;
}
void dfs(int u)
{
for(int i = 1;i<=20;i++)
anc[u][i] = anc[anc[u][i-1]][i-1];
for(int i = head[u];i!=-1;i= edge[i].next)
{
int v = edge[i].v;
if(v!=anc[u][0])
{
anc[v][0] = u;
deep[v] =deep[u]+1;
for(int i = 0 ;i <=50;i++)
{
rlen[i][v] = (rlen[i][u]+qpow(deep[v],i))%mod;
}
dfs(v);
}
}
}
int LCA(int u,int v)
{
if(deep[u]<deep[v]) swap(u,v);
for(int i = 20;i>=0;i--)
{
if(deep[anc[u][i]] >=deep[v])
u =anc[u][i];
}
if(u==v) return u;
for(int i = 20;i>=0;i--)
{
if(anc[u][i]!=anc[v][i])
{
u = anc[u][i];
v = anc[v][i];
}
}
return anc[u][0];
}
void init()
{
cnt = 0;
memset(head,-1,sizeof(head));
}
int main()
{
init();
int n,m,x,y,u,v;
scanf("%d",&n);
for(int i = 1; i<n;i++)
{
scanf("%d %d",&u,&v);
add(u,v);
add(v,u);
}
for(int i =0;i<=50;i++)
rlen[i][1]=0;
deep[1] = 0;
dfs(1);
int q;
scanf("%d",&q);
int kk;
for(int i =1 ;i<=q;i++)
{
scanf("%d %d %d",&x,&y,&kk);
int temp =LCA(x,y);
ll ans = 0;
ans = (ans + rlen[kk][x]+rlen[kk][y]-2*rlen[kk][temp]+qpow(deep[temp],kk))%mod;
ans = (ans +mod) %mod;
printf("%lld\n",ans);
}
}
/*
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
*/