【洛谷P5305】旧词
题目
题目链接:https://www.luogu.com.cn/problem/P5305
给定一棵 \(n\) 个点的有根树,节点标号 \(1 \sim n\),\(1\) 号节点为根。
给定常数 \(k\)。
给定 \(Q\) 个询问,每次询问给定 \(x,y\)。
求:
\[\sum\limits_{i \le x} \text{depth}(\text{lca}(i,y))^k
\]
\(\text{lca}(x,y)\) 表示节点 \(x\) 与节点 \(y\) 在有根树上的最近公共祖先。
\(\text{depth}(x)\) 表示节点 \(x\) 的深度,根节点的深度为 \(1\)。
由于答案可能很大,你只需要输出答案模 \(998244353\) 的结果。
\(n,Q\leq 5\times 10^4;1\leq k\leq 10^9\)。
思路
和 洛谷P4211 LCA 这道题十分相似,唯一的区别就是在 \(\text{dep}\) 外面套上了一个 \(k\) 次方。
原题的做法是离线然后从小到大考虑 \(i\),树剖+线段树把根节点到 \(i\) 的路径全部加一,询问根节点到 \(r\) 的权值和减去根节点到 \(l-1\) 的权值和。
那么依然考虑是否能给每一个节点一个权值,这样从 \(x\) 到根节点的路径权值和恰好等于 \(\text{dep}(x)^k\)。
那么显然对于一个点 \(x\),我们把它的权值设为 \(\text{dep}(x)^k-(\text{dep}(x)-1)^k\) 即可。
那么其他部分依然一样,只不过线段树上一个区间 \([l,r]\) 的权值和就变成了区间内点的权值和乘区间加一的次数。依然可以轻松维护。
时间复杂度 \(O(Q\log^2 n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=50010,MOD=998244353;
int n,m,Q,tot,head[N],ans[N],top[N],son[N],siz[N],dep[N],fa[N],id[N],rk[N];
struct edge
{
int next,to;
}e[N];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
struct node
{
int x,y,id;
}a[N];
bool cmp(node x,node y)
{
return x.x<y.x;
}
void dfs1(int x)
{
dep[x]=dep[fa[x]]+1; siz[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
dfs1(v);
siz[x]+=siz[v];
if (siz[v]>siz[son[x]]) son[x]=v;
}
}
void dfs2(int x,int tp)
{
top[x]=tp; id[x]=++tot; rk[tot]=x;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=son[x]) dfs2(v,v);
}
}
ll fpow(ll x,ll k)
{
ll res=1;
for (;k;k>>=1,x=x*x%MOD)
if (k&1) res=res*x%MOD;
return res;
}
struct SegTree
{
int sum[N*4],ans[N*4],lazy[N*4];
void pushup(int x)
{
sum[x]=(sum[x*2]+sum[x*2+1])%MOD;
ans[x]=(ans[x*2]+ans[x*2+1])%MOD;
}
void pushdown(int x)
{
if (lazy[x])
{
ans[x*2]=(ans[x*2]+1LL*sum[x*2]*lazy[x])%MOD;
ans[x*2+1]=(ans[x*2+1]+1LL*sum[x*2+1]*lazy[x])%MOD;
lazy[x*2]=(lazy[x*2]+lazy[x])%MOD;
lazy[x*2+1]=(lazy[x*2+1]+lazy[x])%MOD;
lazy[x]=0;
}
}
void build(int x,int l,int r)
{
if (l==r)
{
int d=dep[rk[l]];
sum[x]=(fpow(d,m)-fpow(d-1,m)+MOD)%MOD;
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid); build(x*2+1,mid+1,r);
pushup(x);
}
void update(int x,int l,int r,int ql,int qr)
{
if (ql<=l && qr>=r)
{
lazy[x]++; ans[x]=(ans[x]+sum[x])%MOD;
return;
}
pushdown(x);
int mid=(l+r)>>1;
if (ql<=mid) update(x*2,l,mid,ql,qr);
if (qr>mid) update(x*2+1,mid+1,r,ql,qr);
pushup(x);
}
int query(int x,int l,int r,int ql,int qr)
{
if (ql<=l && qr>=r) return ans[x];
pushdown(x);
int mid=(l+r)>>1,res=0;
if (ql<=mid) res+=query(x*2,l,mid,ql,qr);
if (qr>mid) res+=query(x*2+1,mid+1,r,ql,qr);
return res%MOD;
}
}seg;
void upd(int x)
{
for (;x;x=fa[top[x]])
seg.update(1,1,n,id[top[x]],id[x]);
}
int query(int x)
{
int res=0;
for (;x;x=fa[top[x]])
res=(res+seg.query(1,1,n,id[top[x]],id[x]))%MOD;
return res;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d%d",&n,&Q,&m);
for (int i=2;i<=n;i++)
{
scanf("%d",&fa[i]);
add(fa[i],i);
}
for (int i=1;i<=Q;i++)
{
scanf("%d%d",&a[i].x,&a[i].y);
a[i].id=i;
}
sort(a+1,a+1+Q,cmp);
tot=0; dfs1(1); dfs2(1,1);
seg.build(1,1,n);
for (int i=1,j=1;i<=Q;i++)
{
for (;j<=a[i].x;j++) upd(j);
ans[a[i].id]=query(a[i].y);
}
for (int i=1;i<=Q;i++)
cout<<ans[i]<<"\n";
return 0;
}