[BJOI2018]求和
题目概述
题目描述
\(master\) 对树上的求和非常感兴趣。
他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的\(k\) 次方和,而且每次的\(k\) 可能是不同的。
此处节点深度的定义是这个节点到根的路径上的边数。
他把这个问题交给了\(pupil\),但\(pupil\) 并不会这么复杂的操作,你能帮他解决吗?
输入输出格式
输入格式
第一行包含一个正整数\(n\),表示树的节点数。
之后\(n-1\) 行每行两个空格隔开的正整数\(i, j\),表示树上的一条连接点\(i\) 和点\(j\) 的边。
之后一行一个正整数\(m\),表示询问的数量。
之后每行三个空格隔开的正整数\(i, j, k\),表示询问从点\(i\) 到点\(j\) 的路径上所有节点深度的\(k\) 次方和。
由于这个结果可能非常大,输出其对\(998244353\) 取模的结果。
树的节点从\(1\) 开始标号,其中\(1\) 号节点为树的根。
输出格式
对于每组数据输出一行一个正整数表示取模后的结果。
输入输出样例
输入样例 #1
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
输出样例 #1
33
503245989
样例解释
以下用\(d (i)\) 表示第\(i\) 个节点的深度。 对于样例中的树,有\(d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2\)。
因此第一个询问答案为\((2^5 + 1^5 + 0^5)\ mod\ 998244353 = 33\)
第二个询问答案为\((2^{45} + 1^{45} + 2^{45})\ mod\ 998244353 = 503245989\)。
数据范围
对于\(30\%\) 的数据,\(1 \leq n,m \leq 100\)。
对于\(60\%\) 的数据,\(1 \leq n,m \leq 1000\)。
对于\(100\%\) 的数据,\(1 \leq n,m \leq 300000, 1 \leq k \leq 50\)。
友情提示
数据规模较大,请注意使用较快速的输入输出方式。
解题报告
题意理解
- 要你求出一条路经上,每一个点的深度的\(k\)次方之和
- $ 1\ le k \le 50$
算法解析
首先看到下面,这些话,你就可以判断本题目算法为最近公共祖先
- 一棵树上,大量两点路径查询
- 没有任何修改操作。
- 树上节点很多,要求复杂度不高
综上所述,我们发现这道题目满足所有的要求,因此我们可以推断出,本题目使用最近公共祖先。
我们再来分析这道题目如何使用LCA算法。
我们观察数据范围,得到\(k\)的值域很小。
因此我们不妨开一个长度为\(50\)大小的数组,存储每一个\(k\)对应的树。
然后我们可以通过树上差分算法,解决查询问题。
这里就不画图,如果要看的话,就瞅瞅我的树上差分专题吧,上面有这道题目的图,或者康康我的讲课视频,链接
代码解析
#include <bits/stdc++.h>
using namespace std;
const int N=300100,Mod=998244353;
int fa[N][21],n,m;
long long sum[51][N],deep[N],power[51];
vector<int> g[N];
void dfs(int x,int s)
{
for(int i=1; i<=20; i++)
fa[x][i]=fa[fa[x][i-1]][i-1];
for(int y:g[x])
{
if (y==s)
continue;
deep[y]=deep[x]+1;
fa[y][0]=x;
for(int k=1; k<=50; k++)
power[k]=power[k-1]*deep[y]%Mod;
for(int k=1; k<=50; k++)
sum[k][y]=(power[k]+sum[k][x])%Mod;
dfs(y,x);
}
}
inline int Lca(int a,int b)
{
if (deep[a]<deep[b])
swap(a,b);
for(int i=20; i>=0; i--)
if (deep[fa[a][i]]>=deep[b])
a=fa[a][i];
if (a==b)
return a;
for(int i=20; i>=0; i--)
if (fa[a][i]!=fa[b][i])
a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
inline void init()
{
// freopen("data.in","r",stdin);
// freopen("a.out","w",stdout);
scanf("%d",&n);
for(int i=1; i<n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
power[0]=1;
dfs(1,1);
scanf("%d",&m);
for(int i=1; i<=m; i++)
{
int k,a,b;
scanf("%d%d%d",&a,&b,&k);
int LCA=Lca(a,b);
long long ans=sum[k][a]+Mod+sum[k][b]+Mod-sum[k][fa[LCA][0]]-sum[k][LCA];
ans%=Mod;
printf("%lld\n",ans);
}
}
signed main()
{
init();
return 0;
}