luogu P5305 [GXOI/GZOI2019]旧词
先考虑\(k=1\),一个点的深度就是到根节点的路径上的点的个数,所以\(lca(x,y)\)的深度就是\(x\)和\(y\)到根路径的交集路径上的点的个数,那么对于一个询问,我们可以对每个点\(i\le x\),把\(1\)到\(i\)路径上所有点\(+1\),然后查询\(1\)到\(y\)的点权和就行了.现在有多组询问,路径修改可以树剖+在以\(dfn\)序为下标的线段树上修改,然后套可持久化线段树保存每个\(i\)的线段树状态,每次在对应线段树上区间查询即可.可持久化线段树的区间修改可以参考代码
然后\(k>1\),其实可以进行差分,即每次深度为\(dep\)的点加上\(dep^k-(dep-1)^k\),这样深度为\(dep\)的点到根的权值和就是\(dep^k\)
// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<cmath>
#include<ctime>
#include<queue>
#include<map>
#include<set>
#define LL long long
#define db double
using namespace std;
const int N=50000+10,mod=998244353;
int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
int fpow(int a,int b){int an=1;while(b){if(b&1) an=1ll*an*a%mod;a=1ll*a*a%mod,b>>=1;} return an;}
int to[N],nt[N],hd[N],tot=1;
void add(int x,int y){++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;}
int n,q,kk,a[N],ps[N];
int s[N*100],tg[N*100],ch[N*100][2],rt[N],tt;
#define mid ((l+r)>>1)
void modif(int &o,int l,int r,int ll,int rr)
{
++tt,s[tt]=s[o],tg[tt]=tg[o],ch[tt][0]=ch[o][0],ch[tt][1]=ch[o][1],o=tt;
s[o]=(1ll*s[o]+ps[min(r,rr)]-ps[max(l,ll)-1]+mod)%mod;
if(ll<=l&&r<=rr){++tg[o];return;}
if(ll<=mid) modif(ch[o][0],l,mid,ll,rr);
if(rr>mid) modif(ch[o][1],mid+1,r,ll,rr);
}
int quer(int o,int l,int r,int ll,int rr)
{
if(!o) return 0;
if(ll<=l&&r<=rr) return s[o];
int an=1ll*tg[o]*(ps[min(r,rr)]-ps[max(l,ll)-1]+mod)%mod;
if(ll<=mid) an=(an+quer(ch[o][0],l,mid,ll,rr))%mod;
if(rr>mid) an=(an+quer(ch[o][1],mid+1,r,ll,rr))%mod;
return an;
}
int fa[N],de[N],sz[N],hs[N],top[N],dfn[N],ti;
void dfs1(int x)
{
sz[x]=1;
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
de[y]=de[x]+1,dfs1(y);
sz[x]+=sz[y],hs[x]=sz[hs[x]]>sz[y]?hs[x]:y;
}
}
void dfs2(int x,int ntp)
{
top[x]=ntp,dfn[x]=++ti,ps[ti]=a[de[x]];
if(hs[x]) dfs2(hs[x],ntp);
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(y==hs[x]) continue;
dfs2(y,y);
}
}
int main()
{
n=rd(),q=rd(),kk=rd();
for(int i=1;i<=n;++i) a[i]=fpow(i,kk);
for(int i=n;i;--i) a[i]=(a[i]-a[i-1]+mod)%mod;
for(int i=2;i<=n;++i) add(fa[i]=rd(),i);
de[1]=1,dfs1(1),dfs2(1,1);
for(int i=1;i<=n;++i) ps[i]=(ps[i]+ps[i-1])%mod;
for(int i=1;i<=n;++i)
{
rt[i]=rt[i-1];
int x=i;
while(x)
{
modif(rt[i],1,n,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
}
while(q--)
{
int ii=rd(),x=rd(),ans=0;
while(x)
{
ans=(ans+quer(rt[ii],1,n,dfn[top[x]],dfn[x]))%mod;
x=fa[top[x]];
}
printf("%d\n",ans);
}
return 0;
}