[GXOI/GZOI2019]旧词 解题报告
对于一棵 \(n\) 个节点的树,给出 \(m\) 次询问和常数 \(k\) ,每次给出 \(r,x\) ,求 \(\sum\limits_{i=1}^r depth(LCA(i,x))^k\) 。
\(n,m\le 5\times 10^4 , 1\le r,x\le n , k\le 10^9\)
如果有做过[LNOI2014]LCA,就很容易想出解法。
可以通过树上差分,点 \(x\) 的权值为 \(depth(x)^k-(depth(x)-1)^k\) ,那么对于一个点的深度的 \(k\) 次方就等于该点到根的路径上的所有点的权值之和。
可以将询问离线,按照 \(r\) 排序,依次修改 \(1\) 到 \(n\) 的点到根的点权,每次修改就是将 \(x\) 点加上 \(depth(x)^k-(depth(x)-1)^k\) ,那么对于相应的 \(r\) ,查询 \(x\) 到根的路径上的所有点的权值之和就是答案了。
可以通过树剖+线段树解决,效率 \(\mathcal{O(n\log^2 n)}\) ,精细实现可以做到 \(\mathcal{O(n\log n)}\) 。
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int M=5e5+5,JYY=998244353;
void swap(int &x,int &y){ x^=y^=x^=y; }
int min(int x,int y){ return x<y?x:y; }
int max(int x,int y){ return x>y?x:y; }
int n,q,K,mi[M],fa[M],de[M],Ans[M];
int read(){
int x=0,y=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-') y=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*y;
}
int tot=0,first[M];
struct Edge{
int nxt,to;
}e[M<<1];
void add(int x,int y){
e[++tot].nxt=first[x];
first[x]=tot;
e[tot].to=y;
}
int num=0,dfn[M],pre[M],son[M],top[M],size[M];
struct Tree{
int sum,fz,lazy;
}tr[M<<2];
void pushup(int u){ tr[u].sum=(tr[u<<1].sum+tr[u<<1|1].sum)%JYY; }
void pushdown(int u){
if(!tr[u].lazy) return ;
(tr[u<<1].sum+=tr[u<<1].fz*tr[u].lazy%JYY)%=JYY;
(tr[u<<1].lazy+=tr[u].lazy)%=JYY;
(tr[u<<1|1].sum+=tr[u<<1|1].fz*tr[u].lazy%JYY)%=JYY;
(tr[u<<1|1].lazy+=tr[u].lazy)%=JYY;
tr[u].lazy=0;
}
void build(int u,int l,int r){
tr[u].sum=tr[u].lazy=0;
if(l==r) return (void)(tr[u].fz=(mi[de[pre[l]]]-mi[de[pre[l]]-1]+JYY)%JYY);
int mid=(l+r)>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
tr[u].fz=(tr[u<<1].fz+tr[u<<1|1].fz)%JYY;
}
void change(int u,int l,int r,int L,int R,int x){
if(l>R||r<L) return ;
if(l>=L&&r<=R) return (void)((tr[u].sum+=x*tr[u].fz)%=JYY,(tr[u].lazy+=x)%=JYY);
pushdown(u);
int mid=(l+r)>>1;
change(u<<1,l,mid,L,R,x),change(u<<1|1,mid+1,r,L,R,x);
return (void)(pushup(u));
}
int query(int u,int l,int r,int L,int R){
if(l>R||r<L) return 0;
if(l>=L&&r<=R) return tr[u].sum;
pushdown(u);
int mid=(l+r)>>1;
return (query(u<<1,l,mid,L,R)+query(u<<1|1,mid+1,r,L,R))%JYY;
}
void dfs1(int u){
de[u]=de[fa[u]]+1;size[u]=1;
for(int i=first[u];i;i=e[i].nxt){
int v=e[i].to;
dfs1(v);size[u]+=size[v];
if(size[son[u]]<size[v]) son[u]=v;
}
}
void dfs2(int u,int tp){
dfn[u]=++num,pre[num]=u;top[u]=tp;
if(son[u]) dfs2(son[u],tp);
for(int i=first[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==son[u]) continue ;
dfs2(v,v);
}
}
void Change(int x,int d){
while(top[x]!=1){
change(1,1,n,dfn[top[x]],dfn[x],d);
x=fa[top[x]];
}
return (void)(change(1,1,n,dfn[1],dfn[x],d));
}
int Query(int x){
int res=0;
while(top[x]!=1){
res=(res+query(1,1,n,dfn[top[x]],dfn[x]))%JYY;
x=fa[top[x]];
}
return (res+query(1,1,n,dfn[1],dfn[x]))%JYY;
}
int qpow(int x,int y){
int res=1;
for(;y;x=x*x%JYY,y>>=1) if(y&1) res=res*x%JYY;
return res;
}
struct Ques{ int r,x,id; }Q[M];
bool cmp(Ques x,Ques y){ return x.r<y.r; }
void solve(){
n=read(),q=read(),K=read()%(JYY-1);mi[0]=0;mi[1]=1;
for(int i=2;i<=n;i++) fa[i]=read(),add(fa[i],i),mi[i]=qpow(i,K);
dfs1(1);dfs2(1,1);build(1,1,n);
// for(int i=1;i<=n;i++) printf("%lld ",(mi[de[i]]-mi[de[i]-1]+JYY)%JYY);printf("\n");
for(int i=1;i<=q;i++){
int r=read(),x=read();
Q[i]=(Ques){r,x,i};
}
sort(Q+1,Q+q+1,cmp);
// for(int i=1;i<=2*q;i++){
// printf("%d %d %d %d\n",Q[i].x,Q[i].z,Q[i].z,Q[i].id);
// }
for(int i=1,now=0;i<=q;i++){
while(now<Q[i].r) now++,Change(now,1);
Ans[Q[i].id]=Query(Q[i].x);
}
for(int i=1;i<=q;i++) printf("%lld\n",Ans[i]);
}
signed main(){
solve();
}