[BJOI2018]求和(树链剖分)
题目描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 kkk 次方和,而且每次的 kkk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
输入输出格式
输入格式:第一行包含一个正整数 nnn ,表示树的节点数。
之后 n−1n-1n−1 行每行两个空格隔开的正整数 i,ji, ji,j ,表示树上的一条连接点 iii 和点 jjj 的边。
之后一行一个正整数 mmm ,表示询问的数量。
之后每行三个空格隔开的正整数 i,j,ki, j, ki,j,k ,表示询问从点 iii 到点 jjj 的路径上所有节点深度的 kkk 次方和。由于这个结果可能非常大,输出其对 998244353998244353998244353 取模的结果。
树的节点从 111 开始标号,其中 111 号节点为树的根。
输出格式:对于每组数据输出一行一个正整数表示取模后的结果。
输入输出样例
说明
样例解释
以下用 d(i)d (i)d(i) 表示第 iii 个节点的深度。
对于样例中的树,有 d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2 。
因此第一个询问答案为 (25+15+05) mod 998244353=33(2^5 + 1^5 + 0^5)\ mod\ 998244353 = 33(25+15+05) mod 998244353=33 ,第二个询问答案为 (245+145+245) mod 998244353=503245989(2^{45} + 1^{45} + 2^{45})\ mod\ 998244353 = 503245989(245+145+245) mod 998244353=503245989 。
数据范围
对于 30%30\%30% 的数据, 1≤n,m≤1001 \leq n,m \leq 1001≤n,m≤100 。
对于 60%60\%60% 的数据, 1≤n,m≤10001 \leq n,m \leq 10001≤n,m≤1000 。
对于 100%100\%100% 的数据, 1≤n,m≤300000,1≤k≤501 \leq n,m \leq 300000, 1 \leq k \leq 501≤n,m≤300000,1≤k≤50 。
另外存在5个不计分的hack数据
提示
数据规模较大,请注意使用较快速的输入输出方式。
#include<cstdio> #include<cstring> #include<cmath> #include<iostream> #include<algorithm> #define ll long long using namespace std; const int N=300001; const int mod=998244353; int n,m,dep[51][N],fa[N],size[N],son[N],maxx,fid[N]; int l[N],tot,head[N*2],num,top[N]; struct node{ int to,next; }e[N*2]; struct tr{ int sum[N*4]; }t[51]; int read() { int x=0,w=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')w=-1;ch=getchar();} while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar(); return x*w; } void add(int from,int to) { num++; e[num].to=to; e[num].next=head[from]; head[from]=num; } void dfs1(int x) { size[x]=1; for(int i=head[x];i;i=e[i].next) { int v=e[i].to; if(!dep[1][v]&&v!=1) { dep[1][v]=dep[1][x]+1;fa[v]=x; maxx=max(maxx,dep[1][v]); dfs1(v); size[x]+=size[v]; if(size[v]>size[son[x]])son[x]=v; } } } void dfs2(int x,int tp) { l[x]=++tot;top[x]=tp;fid[tot]=x; if(son[x])dfs2(son[x],tp); for(int i=head[x];i;i=e[i].next) { int v=e[i].to; if(v!=fa[x]&&v!=son[x]) dfs2(v,v); } } void init() { n=read(); for(int i=1;i<n;i++) { int x=read(),y=read(); add(x,y);add(y,x); } dfs1(1); dep[1][1]=0;fa[1]=1; dfs2(1,1); } void build(int rt,int root,int left,int right) { if(left==right){ t[rt].sum[root]=dep[rt][fid[left]]; return ; } int mid=(left+right)>>1; build(rt,root<<1,left,mid); build(rt,root<<1|1,mid+1,right); t[rt].sum[root]=(ll)(t[rt].sum[root<<1]+t[rt].sum[root<<1|1])%mod; if(t[rt].sum[root]>=mod)t[rt].sum[root]-=mod; } int query(int rt,int root,int left,int right,int L,int R) { if(left>R||right<L)return 0; if(L<=left&&right<=R)return t[rt].sum[root]%mod; int mid=(left+right)>>1; ll a=0,b=0; if(mid>=L)a=query(rt,root<<1,left,mid,L,R)%mod; if(mid<R) b=query(rt,root<<1|1,mid+1,right,L,R)%mod; return (a+b)%mod; } int cal(int rt,int x,int y) { ll ans=0; int fx=top[x],fy=top[y]; while(fx!=fy) { if(dep[fx]<dep[fy]){swap(fx,fy);swap(x,y);} ans+=query(rt,1,1,n,l[fx],l[x]); ans%=mod; x=fa[fx],fx=top[x]; } if(l[x]>l[y])swap(x,y); ans+=query(rt,1,1,n,l[x],l[y]); ans%=mod; return ans; } void perp() { for(int i=2;i<=50;i++) for(int j=1;j<=n;j++) {dep[i][j]=((ll)dep[i-1][j]%mod*(ll)dep[1][j]%mod)%mod;} for(int i=1;i<=50;i++) build(i,1,1,n); } void solve() { m=read(); for(int i=1;i<=m;i++) {int x=read(),y=read(),k=read(); printf("%d\n",cal(k,x,y));} } int main() { init(); perp(); solve(); return 0; }