2022 ICPC 网络赛(II) H Fast Fourier Transform题解

简要题意

给你一棵树,你可以选若干节点为关键点,定义一个选点方案的价值为:所有路径上没有关键点的点对的距离之和。求所有选点方案的价值之和。

题解

一开始和队友都读错题了,以为在一个方案中一条边只会贡献一次,然后这样直接计算每条边的贡献就秒掉了,于是开始疑惑为什么大伙都没有切这个题。

然后邓老师写完代码了发现样例不对,手模样例发现手模也算不对,然后就意识到题读错了,还有五分钟就结束了所以就摆烂了。

题目的关键是要意识到,一对点的贡献只与他们两点之间的距离有关。

显然,只要他们路径上的点不被选,剩下的点随便选都无所谓。

所以若距离为 \(d\) ,它们的贡献可以写成

\[d\times \sum_i (i+2)\times {n-d-1 \choose i}\\ =d\times(n-d+3)\times2^{n-d-2} \]

形如 \(x\cdot y\cdot 2^y\) ,记录当前点的答案为 \(\sum x\cdot y \cdot 2^y\) ,如果想要上传答案给父亲就是 \(x+1,y-1\) ,同时这个维护 \(\sum x\cdot 2^y,\sum y\cdot 2^y,\sum 2^y\) 就能算了。

换根 dp 在第二次 dfs 的时候算当前节点和子树外节点的贡献的时候,就是把父亲节点的所有贡献减去父亲节点与当前子树内节点形成的点对的贡献,再使 \(x+1,y-1\) 即可。

然后是代码

#include <bits/stdc++.h>
#define N 1000006
using namespace std;
typedef long long ll;
int n,o2,yy;
const ll mod=998244353;
vector<int> ed[N];
int fa[N];
struct ANS{
	ll xy2,x2,y2,_2;
	friend ANS operator +(ANS a,ANS b){ 
		a.x2=(a.x2+b.x2)%mod,
		a.xy2=(a.xy2+b.xy2)%mod,
		a.y2=(a.y2+b.y2)%mod,
		a._2=(a._2+b._2)%mod;
		return a;
	}
	friend ANS operator -(ANS a,ANS b){ 
		a.x2=(a.x2-b.x2+mod)%mod,
		a.xy2=(a.xy2-b.xy2+mod)%mod,
		a.y2=(a.y2-b.y2+mod)%mod,
		a._2=(a._2-b._2+mod)%mod;
		return a;
	}
}ans1[N],ans2[N];
ANS add(ANS x){
	ANS a;
	a.xy2=((x.xy2-x.x2+mod+x.y2-x._2+mod)%mod*o2)%mod;
	a.x2=(x.x2+x._2)*o2%mod;
	a.y2=(x.y2-x._2+mod)*o2%mod;
	a._2=x._2*o2%mod;
	a=a+ans1[0];
	return a;
}
ll ksm(ll x,ll y){
	ll res=1;
	while(y){
		if(y&1) res=res*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return res;
}
void dfs1(int x){
	for(int y: ed[x]){
		if(y==fa[x]) continue;
		fa[y]=x,dfs1(y);
		ans1[x]=ans1[x]+add(ans1[y]);
	}
}
void dfs2(int x){
	if(x==1) ans2[x]=ans1[x];
	else ans2[x]=add(ans2[fa[x]]-add(ans1[x]))+ans1[x];
	for(int y: ed[x]){
		if(y==fa[x]) continue;
		dfs2(y);
	}
}
int main(){
	o2=(mod+1)/2;
	scanf("%d",&n);
	yy=n+2;
	ans1[0]._2=ksm(2,yy),ans1[0].y2=yy*ksm(2,yy)%mod;
	ans1[0].xy2=ans1[0].y2,ans1[0].x2=ans1[0]._2;
	int u,v;
	for(int i=1;i<n;i++){
		scanf("%d %d",&u,&v);
		ed[u].push_back(v),ed[v].push_back(u);
	}
	dfs1(1);
	dfs2(1);
	ll ans=0;
	for(int i=1;i<=n;i++) ans=(ans+ans2[i].xy2)%mod;
	o2=ksm(2,6),o2=ksm(o2,mod-2);
	ans=ans*o2%mod;
	cout<<ans;
}
posted @ 2022-10-10 20:20  缙云山车神  阅读(130)  评论(0编辑  收藏  举报