【洛谷P5305】旧词

题目

题目链接:https://www.luogu.com.cn/problem/P5305
给定一棵 \(n\) 个点的有根树,节点标号 \(1 \sim n\)\(1\) 号节点为根。
给定常数 \(k\)
给定 \(Q\) 个询问,每次询问给定 \(x,y\)
求:

\[\sum\limits_{i \le x} \text{depth}(\text{lca}(i,y))^k \]

\(\text{lca}(x,y)\) 表示节点 \(x\) 与节点 \(y\) 在有根树上的最近公共祖先。
\(\text{depth}(x)\) 表示节点 \(x\) 的深度,根节点的深度为 \(1\)
由于答案可能很大,你只需要输出答案模 \(998244353\) 的结果。
\(n,Q\leq 5\times 10^4;1\leq k\leq 10^9\)

思路

洛谷P4211 LCA 这道题十分相似,唯一的区别就是在 \(\text{dep}\) 外面套上了一个 \(k\) 次方。
原题的做法是离线然后从小到大考虑 \(i\),树剖+线段树把根节点到 \(i\) 的路径全部加一,询问根节点到 \(r\) 的权值和减去根节点到 \(l-1\) 的权值和。
那么依然考虑是否能给每一个节点一个权值,这样从 \(x\) 到根节点的路径权值和恰好等于 \(\text{dep}(x)^k\)
那么显然对于一个点 \(x\),我们把它的权值设为 \(\text{dep}(x)^k-(\text{dep}(x)-1)^k\) 即可。
那么其他部分依然一样,只不过线段树上一个区间 \([l,r]\) 的权值和就变成了区间内点的权值和乘区间加一的次数。依然可以轻松维护。
时间复杂度 \(O(Q\log^2 n)\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=50010,MOD=998244353;
int n,m,Q,tot,head[N],ans[N],top[N],son[N],siz[N],dep[N],fa[N],id[N],rk[N];

struct edge
{
	int next,to;
}e[N];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

struct node
{
	int x,y,id;
}a[N];

bool cmp(node x,node y)
{
	return x.x<y.x;
}

void dfs1(int x)
{
	dep[x]=dep[fa[x]]+1; siz[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		dfs1(v);
		siz[x]+=siz[v];
		if (siz[v]>siz[son[x]]) son[x]=v;
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp; id[x]=++tot; rk[tot]=x;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=son[x]) dfs2(v,v);
	}
}

ll fpow(ll x,ll k)
{
	ll res=1;
	for (;k;k>>=1,x=x*x%MOD)
		if (k&1) res=res*x%MOD;
	return res;
}

struct SegTree
{
	int sum[N*4],ans[N*4],lazy[N*4];
	
	void pushup(int x)
	{
		sum[x]=(sum[x*2]+sum[x*2+1])%MOD;
		ans[x]=(ans[x*2]+ans[x*2+1])%MOD;
	}
	
	void pushdown(int x)
	{
		if (lazy[x])
		{
			ans[x*2]=(ans[x*2]+1LL*sum[x*2]*lazy[x])%MOD;
			ans[x*2+1]=(ans[x*2+1]+1LL*sum[x*2+1]*lazy[x])%MOD;
			lazy[x*2]=(lazy[x*2]+lazy[x])%MOD;
			lazy[x*2+1]=(lazy[x*2+1]+lazy[x])%MOD;
			lazy[x]=0;
		}
	}
	
	void build(int x,int l,int r)
	{
		if (l==r)
		{
			int d=dep[rk[l]];
			sum[x]=(fpow(d,m)-fpow(d-1,m)+MOD)%MOD;
			return;
		}
		int mid=(l+r)>>1;
		build(x*2,l,mid); build(x*2+1,mid+1,r);
		pushup(x);
	}
	
	void update(int x,int l,int r,int ql,int qr)
	{
		if (ql<=l && qr>=r)
		{
			lazy[x]++; ans[x]=(ans[x]+sum[x])%MOD;
			return;
		}
		pushdown(x);
		int mid=(l+r)>>1;
		if (ql<=mid) update(x*2,l,mid,ql,qr);
		if (qr>mid) update(x*2+1,mid+1,r,ql,qr);
		pushup(x);
	}
	
	int query(int x,int l,int r,int ql,int qr)
	{
		if (ql<=l && qr>=r) return ans[x];
		pushdown(x);
		int mid=(l+r)>>1,res=0;
		if (ql<=mid) res+=query(x*2,l,mid,ql,qr);
		if (qr>mid) res+=query(x*2+1,mid+1,r,ql,qr);
		return res%MOD;
	}
}seg;

void upd(int x)
{
	for (;x;x=fa[top[x]])
		seg.update(1,1,n,id[top[x]],id[x]);
}

int query(int x)
{
	int res=0;
	for (;x;x=fa[top[x]])
		res=(res+seg.query(1,1,n,id[top[x]],id[x]))%MOD;
	return res;
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d%d",&n,&Q,&m);
	for (int i=2;i<=n;i++)
	{
		scanf("%d",&fa[i]);
		add(fa[i],i);
	}
	for (int i=1;i<=Q;i++)
	{
		scanf("%d%d",&a[i].x,&a[i].y);
		a[i].id=i;
	}
	sort(a+1,a+1+Q,cmp);
	tot=0; dfs1(1); dfs2(1,1);
	seg.build(1,1,n);
	for (int i=1,j=1;i<=Q;i++)
	{
		for (;j<=a[i].x;j++) upd(j);
		ans[a[i].id]=query(a[i].y);
	}
	for (int i=1;i<=Q;i++)
		cout<<ans[i]<<"\n";
	return 0;
}
posted @ 2021-06-02 12:20  stoorz  阅读(58)  评论(0编辑  收藏  举报