P4427 [BJOI2018]求和

题目描述

master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的kk 次方和,而且每次的kk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?

输入格式

第一行包含一个正整数\(n\),表示树的节点数。

之后\(n-1\) 行每行两个空格隔开的正整数\(i, j\),表示树上的一条连接点\(i\)和点\(j\)的边。

之后一行一个正整数\(m\)表示询问的数量。

之后每行三个空格隔开的正整数\(i, j, k\),表示询问从点ii 到点jj 的路径上所有节点深度的\(k\) 次方和。由于这个结果可能非常大,输出其对\(998244353\) 取模的结果。

树的节点从\(1\) 开始标号,其中\(1\)号节点为树的根。

输出格式

对于每组数据输出一行一个正整数表示取模后的结果。

输入输出样例

输入 #1复制

5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45

输出 #1复制

33
503245989

说明/提示

样例解释

以下用\(d (i)\) 表示第ii 个节点的深度。

对于样例中的树,有\(d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2\)

因此第一个询问答案为\((2^5 + 1^5 + 0^5)\ mod\ 998244353\),第二个询问答案为\((2^{45} + 1^{45} + 2^{45})\ mod\ 998244353 = 503245989\)

数据范围

对于\(30\%\) 的数据,\(1 \leq n,m \leq 100\)

对于\(60\%\) 的数据,\(1 \leq n,m \leq 1000\)

对于\(100\%\) 的数据,\(1 \leq n,m \leq 300000, 1 \leq k \leq 50\)

另外存在5个不计分的hack数据

提示

数据规模较大,请注意使用较快速的输入输出方式。

敲完树剖求lca华丽走人

我们可以发现,lca的情况无非就是三种

1.\(lca==a\)

2.\(lca==b\)

3.\(lca\)\(a\)\(b\)的上面

1,2情况直接暴力跳就行,

3.情况分别从\(a\)\(lca\)和从\(b\)\(lca\)跳,然后我们发现\(lca\)算了两次,然后再减去一次\(lca\)的贡献就行

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int M=400100;
const int N=400100;
int ne[M],head[M],ver[M],idx;
int dep[N],fa[N],son[N],sz[N],top[N];
long long ans;
int n,m;
inline void add(int u,int v)
{
	ne[idx]=head[u];
	ver[idx]=v;
	head[u]=idx;
	idx++;
}

inline void dfs1(int u,int father,int depth)
{
	fa[u]=father;
	sz[u]=1;
	dep[u]=depth;
	for(int i=head[u]; i!=-1; i=ne[i])
	{
		int j=ver[i];
		if(j==father)continue;
		dfs1(j,u,depth+1);
		sz[u]+=sz[j];
		if(sz[son[u]]<sz[j]) son[u]=j;
	}
}

inline void dfs2(int u,int t)
{
	top[u]=t;
	if(!son[u])	return ;
	dfs2(son[u],t);
	for(int i=head[u]; i!=-1; i=ne[i])
	{
		int j=ver[i];
		if(j==fa[u]||j==son[u])continue;
		dfs2(j,j);
	}
}

inline int lca(int u,int v)
{
	while(top[u]!=top[v])
	{
		if(dep[top[u]]<dep[top[v]])
			swap(u,v);
		u=fa[top[u]];
	}
	if(dep[u]<dep[v]) swap(u,v);
	return v;
}

inline int qmi(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=(long long)ans*a%mod;
		a=(long long)a*a%mod;
		b>>=1;
	}
	return ans;
}
inline int read()
{
	int x=0;
	int f=1;
	char ch;
	ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=x*10,x=x+ch-'0';
		ch=getchar();
	}
	return x*f;
}
int main()
{

	memset(head,-1,sizeof(head));
	n=read();
	for(register int i=1; i<n; i++)
	{
		int u,v;
		u=read();
		v=read();
		add(u,v);
		add(v,u);
	}
	dfs1(1,0,0);
	dfs2(1,1);
	m=read();
	for(register int i=1; i<=m; i++)
	{
		int a,b,k;
		ans=0;
		a=read();
		b=read();
		k=read();
		int LCA=lca(a,b);
		if(LCA==a)
		{
			for(register int j=dep[a]; j<=dep[b]; j++)
			{
				ans=(ans+qmi(j,k)+mod)%mod;
			}
		}
		else if(LCA==b)
		{
			for(register int j=dep[b]; j<=dep[a]; j++)
			{
				ans=(ans+qmi(j,k)+mod)%mod;
			}
		}
		else
		{
			for(register int j=dep[LCA]; j<=dep[b]; j++)
			{
				ans=(ans+qmi(j,k)%mod+mod)%mod;
			}
			for(register int j=dep[LCA]; j<=dep[a]; j++)
			{
				ans=(ans+qmi(j,k)%mod+mod)%mod;
			}
			ans=(ans-qmi(dep[LCA],k)%mod+mod)%mod;
		}
		printf("%lld\n",ans);
	}
	return 0;
}

posted @ 2020-11-24 12:23  邦的轩辕  阅读(137)  评论(0编辑  收藏  举报