题目描述
master对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k次方和,而且每次的k可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil并不会这么复杂的操作,你能帮他解决吗?
输入
第一行包含一个正整数n,表示树的节点数。
之后n−1行每行两个空格隔开的正整数i,j,表示树上的一条连接点i和点j的边。
之后一行一个正整数m,表示询问的数量。
之后每行三个空格隔开的正整数i,j,k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。
树的节点从1开始标号,其中1号节点为树的根。
输出
对于每组数据输出一行一个正整数表示取模后的结果。
样例输入
5 1 2 1 3 2 4 2 5 2 1 4 5 5 4 45
样例输出
33 503245989
提示
以下用d(i)表示第i个节点的深度。
对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为(25+15+05) mod 998244353=33,第二个询问答案为(245+145+245) mod 998244353=503245989。
对于30%的数据,1≤n,m≤100;
对于60%的数据,1≤n,m≤1000;
对于100%的数据,1≤n,m≤300000,1≤k≤50。
思路:利用前向星见图bfs过程中,记录深度维护val[i][j],i表th[]示当前结点,j表示k,val[i][j]=val[fa[i]][j]+cal(de[i],j),这样可以维护一个前缀和,后面给出查询的时候,先用倍增法求出lca,然后val[a][k]+val[b][k]-2*val[lca][k]+cal(lca,k)就是这条路径上的深度K次方和。
代码:(照着狂兵模板改的):)
#include<bits/stdc++.h>
#include<iostream>
#include<cstring>
#include<string>
#include<queue>
#include<map>
using namespace std;
const int maxn = 300010;
const int DEG = 20; //树的最多层数
const int MOD=998244353;
struct Edge { //边
int to, next;
} edge[maxn * 2];
int head[maxn], tot; //前向星链表,head[]为链头,tot为边的总数
long long sum_path; //sum_path为全局变量
void addedge(int u, int v) { //前向星链表添加边
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}
void init() { //初始化
tot = 0;
memset(head, -1, sizeof(head));
}
long long k;
long long cal(long long dep,long long k){
long long res=1;
while(k){
if(k%2==1) res=(res*dep)%MOD;
dep=(dep*dep)%MOD;
k/=2;
}
return res;
}
long long val[maxn][55];
int fa[maxn][DEG];//fa[i][j]表示结点i的第2^j个祖先
long long deg[maxn];//深度数组
void BFS(int root) { //bds预处理出每一个结点u的fa[u][i]以及深度
queue<int>que;
deg[root] = 0;
fa[root][0] = root;
que.push(root);
for(int i=1;i<=50;i++) val[root][i]=0;
while (!que.empty()) {
int tmp = que.front();
que.pop();
for (int i = 1; i < DEG; i++){ //u到2^i个父亲可以拆成先到2^(i-1)父亲处,再从该处到相应的2^(i-1)父亲处
fa[tmp][i] = fa[fa[tmp][i - 1]][i - 1];
}
for (int i = head[tmp]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if (v == fa[tmp][0])continue;
deg[v] = deg[tmp] + 1;
fa[v][0] = tmp;
for(int i=1;i<=50;i++) val[v][i]=(val[tmp][i]+cal(deg[v],i)%MOD)%MOD;
que.push(v);
}
}
}
int LCA(int u, int v) { //求u和v的最近公共祖先,以及路径上的权值总和
if (deg[u] > deg[v])swap(u, v);
int hu = deg[u], hv = deg[v];
int tu = u, tv = v;
sum_path = 0;
for (int det = hv - hu, i = 0; det; det >>= 1, i++) //先让深度较大的tv往上跳跃到和tu相同的高度
if (det & 1){
tv = fa[tv][i];
}
if (tu == tv)return tu; //如果此时在同一结点,说明已经到达最近公共祖先,返回最近最近公共祖先
for (int i = DEG - 1; i >= 0; i--) { //一起往上跳,不断逼近最近公共祖先,但不到达最近公共祖先
if (fa[tu][i] == fa[tv][i])
continue;
tu = fa[tu][i];
tv = fa[tv][i];
}
if (fa[tu][0] == fa[tv][0]){ //tu和tv再往上一个单位就是最近公共祖先了
return fa[tu][0];
}
else //否则没有最近公共祖先,返回-1
return -1;
}
bool flag[maxn];
char s[10];
int main() {
int m, u, v, w, q;
scanf("%d",&m);
init(); //初始化
for (int i = 1; i < m; i++){
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
}
BFS(1); //以结点1开始bfs遍历建树
scanf("%d", &q);
for (int i = 1; i <= q; i++){
scanf("%d%d%lld", &u, &v,&k);
int lca=LCA(u, v); //u和v的最近公共祖先
sum_path=(val[u][k]+val[v][k]-(2*val[lca][k])%MOD+cal(deg[lca],k)+2*MOD)%MOD;
printf("%lld\n", sum_path); //sum_path为u到v的路径权值总和
}
return 0;
}
/*
5
1 2
2 3
3 4
3 5
3
3 4 2
3 5 2
4 5 2
*/