luogu P5305 [GXOI/GZOI2019]旧词
题面传送门
你会发现这道题和LNOI2014某题很像。
但是那个\(k\)次方很难处理。
考虑\(k=1\)的情况,就是那道题。
照样差分,但是这次差分不是那么差,而是每个点的权值改成\(d_i^k-(d_i-1)^k\)这东西就可以实现了。
因为加到一个点时这个点到根节点的路径都会被加。而这个值又恰好等于深度的\(k\)次方。
那么把原来的线段树改成带权的就好了。
代码实现:
#include<cstdio>
#include<cstring>
#include<vector>
#define mod 998244353
using namespace std;
int n,m,k,x,y,z;
int siz[50039],d[50039],son[50039],top[50039],id[50039],fa[50039],idea,ids[50039];
long long ans[50039],f[200039],sum[200039],g[200039];
inline void read(int &x){
char s=getchar();x=0;
while(s<'0'||s>'9') s=getchar();
while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+(s^48),s=getchar();
}
struct yyy{int to,z;};
struct ljb{
int head,h[50039];
yyy f[100039];
inline void add(int x,int y){
f[++head]=(yyy){y,h[x]};
h[x]=head;
}
}s;
inline void dfs1(int x,int last){
d[x]=d[last]+1;
fa[x]=last;
siz[x]=1;
int cur=s.h[x],pus=0;
yyy tmp;
while(cur!=-1){
tmp=s.f[cur];
if(tmp.to!=last){
dfs1(tmp.to,x);
siz[x]+=siz[tmp.to];
if(siz[pus]<siz[tmp.to]) pus=tmp.to;
}
cur=tmp.z;
}
son[x]=pus;
}
inline void dfs2(int x,int last){
top[x]=last;
id[x]=++idea;
ids[idea]=d[x];
if(!son[x]) return;
dfs2(son[x],last);
int cur=s.h[x];
yyy tmp;
while(~cur){
tmp=s.f[cur];
if(tmp.to!=son[x]&&tmp.to!=fa[x]) dfs2(tmp.to,tmp.to);
cur=tmp.z;
}
}
struct ques{int to,num;}tmp;
vector<ques> fs[50039];
inline void swap(int &x,int &y){x^=y^=x^=y;}
inline long long pow(long long x,int y){
long long ans=1;
while(y){
if(y&1) ans=ans*x%mod;
x=x*x%mod;
y>>=1;
}
return ans;
}
inline void jianshu(int l,int r,int now){
if(l==r){
g[now]=(pow(ids[l],k)-pow(ids[l]-1,k)+mod)%mod;
return;
}
int m=(l+r)>>1;
jianshu(l,m,now<<1);jianshu(m+1,r,now<<1|1);
g[now]=(g[now<<1]+g[now<<1|1])%mod;
}
inline void push(int now){
if(f[now]){
f[now<<1]+=f[now];f[now<<1|1]+=f[now];
sum[now<<1]+=g[now<<1]*f[now];sum[now<<1|1]+=g[now<<1|1]*f[now];
f[now]=0;
}
}
inline void get(int x,int y,int l,int r,int now){
if(x<=l&&r<=y) {
f[now]++;
sum[now]+=g[now];
return;
}
push(now);
int m=(l+r)>>1;
if(x<=m) get(x,y,l,m,now<<1);
if(y>m) get(x,y,m+1,r,now<<1|1);
sum[now]=sum[now<<1]+sum[now<<1|1];
}
inline long long find(int x,int y,int l,int r,int now){
if(x<=l&&r<=y) return sum[now];
int m=l+r>>1;long long fs=0;
push(now);
if(x<=m) fs+=find(x,y,l,m,now<<1);
if(y>m) fs+=find(x,y,m+1,r,now<<1|1);
return fs;
}
inline void gets(int x,int y){
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
get(id[top[x]],id[x],1,n,1);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
get(id[x],id[y],1,n,1);
}
inline long long finds(int x,int y){
long long ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=find(id[top[x]],id[x],1,n,1);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
return ans+find(id[x],id[y],1,n,1);
}
int main(){
// freopen("1.in","r",stdin);
register int i,j;
memset(s.h,-1,sizeof(s.h));
scanf("%d%d%d",&n,&m,&k);
for(i=2;i<=n;i++) read(x),s.add(i,x),s.add(x,i);
dfs1(1,0);dfs2(1,1);jianshu(1,n,1);
for(i=1;i<=m;i++){
read(x);read(z);
fs[x].push_back((ques){z,i});
}
for(i=1;i<=n;i++){
gets(1,i);
for(j=0;j<fs[i].size();j++){
tmp=fs[i][j];
ans[tmp.num]=finds(1,tmp.to)%mod;
}
}
for(i=1;i<=m;i++) printf("%lld\n",ans[i]);
}