[搜索]求和VII
题目描述
master对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k次方和,而且每次的k可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil并不会这么复杂的操作,你能帮他解决吗?
输入
第一行包含一个正整数n,表示树的节点数。
之后n−1行每行两个空格隔开的正整数i,j,表示树上的一条连接点i和点j的边。
之后一行一个正整数m,表示询问的数量。
之后每行三个空格隔开的正整数i,j,k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。
树的节点从1开始标号,其中1号节点为树的根。
之后n−1行每行两个空格隔开的正整数i,j,表示树上的一条连接点i和点j的边。
之后一行一个正整数m,表示询问的数量。
之后每行三个空格隔开的正整数i,j,k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。
树的节点从1开始标号,其中1号节点为树的根。
输出
对于每组数据输出一行一个正整数表示取模后的结果。
样例输入
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
样例输出
33
503245989
提示
以下用d(i)表示第i个节点的深度。
对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为(25+15+05) mod 998244353=33,第二个询问答案为(245+145+245) mod 998244353=503245989。
对于30%的数据,1≤n,m≤100;
对于60%的数据,1≤n,m≤1000;
对于100%的数据,1≤n,m≤300000,1≤k≤50。
思路:根据询问的x,y的depth[x],depth[y]可以知道要求的和,在线求是不可能的,预处理搞一搞
AC代码:
(更新...原来这是LCA,我竟然在还没学LCA的情况下做出了一道LCA的题...
#include <iostream> #include<cstdio> #include<vector> #define mod 998244353 typedef long long ll; using namespace std; int n,max_depth; vector<int> edge[300010]; vector<int> sons[300010]; int fa[300010]; ll sum[300010][51]; int depth[300010]; inline ll qpow(ll a,ll b){ ll ret=1; while(b){ if(b&1) ret=ret*a%mod; a=a*a%mod; b>>=1; } return ret%mod; } inline void init(){ for(int k=1;k<=50;k++){ sum[0][k]=0; for(int i=1;i<=max_depth;i++){ sum[i][k]=(sum[i-1][k]+qpow(i,k)%mod)%mod; } } } inline void dfs_build(int x){ for(int i=0;i<(int)edge[x].size();i++){ int v=edge[x][i]; if(depth[v]==-1){ depth[v]=depth[x]+1; max_depth=max(max_depth,depth[v]); sons[x].push_back(v); fa[v]=x; dfs_build(v); } } } inline int check(int x,int y){ int tmp=x; while(depth[tmp]!=depth[y]){ tmp=fa[tmp]; } if(tmp==y) return -1; else{ int fa1=tmp,fa2=y; while(fa1!=fa2){ fa1=fa[fa1]; fa2=fa[fa2]; } return depth[fa1]; } } int main() { scanf("%d",&n); for(int i=1;i<=n-1;i++){ int x,y;scanf("%d%d",&x,&y); edge[x].push_back(y); edge[y].push_back(x); } for(int i=0;i<=n;i++) depth[i]=-1; depth[1]=0; dfs_build(1); init(); int m;scanf("%d",&m); while(m--){ int x,y,k;scanf("%d%d%d",&x,&y,&k); if(depth[x]>depth[y]) swap(x,y); ll ans=0; if(depth[x]==0) ans=sum[depth[y]][k]; else{ ans=(sum[depth[y]][k]-sum[depth[x]-1][k]+mod)%mod; int tmp=check(y,x); if(tmp!=-1){ if(tmp==0) ans=(sum[depth[y]][k]+sum[depth[x]][k])%mod; else ans=(sum[depth[y]][k]-sum[tmp-1][k]+mod+sum[depth[x]][k]-sum[tmp][k]+mod)%mod; } } printf("%lld\n",ans); } return 0; }
转载请注明出处:https://www.cnblogs.com/lllxq/