朝花夕拾:NHOI 2022 T6

原题
题意:

题目描述

给定一颗树有 \(n\) 个结点,每个结点上有一个权值 \(a_i\), 对于每条至少包含两个点简单路径,它的贡献为 路径上点的数量(包括端点)\(\times\)路径上所有点的 \(a_i\) 的最大公约数(gcd)。
求所有简单路径的贡献之和,对 \(998244353\) 取模。
看到了 \(gcd\) 这个很难搞的东西 一开始的想法就是点分治了 但是肯定是不行的 因为 \(gcd\) 是很不好维护的
换一种方法 考虑欧拉反演

每次枚举一个公约数 然后把是这个公约数倍数的点给找出来
这个过程可以用 vector 预处理
然后转化为在这些点中 找出所有路径

然后找路径很麻烦 考虑转化成每个点的贡献 算一下每个点被多少条路径经过即可
这个过程随便搞就行了 时间复杂度是 \(O(\text{点数})\)
因为每个数 \(\leq 10^5\) 因此因数不会超过 \(128\) 个 因此时间复杂度均摊是 \(O(wn)\)\(w=128\)

Code

#include<bits/stdc++.h>
#define N 100005
#define ll long long
using namespace std;
ll mod=998244353;
ll sum,ans;
int n,a[N],m,all;
int isp[N],phi[N];
int p[N],len;
int head[N],tot=1;
struct edge{
	int to,next;
}e[N*2];
void add(int u,int v)
{
	e[tot]=(edge){v,head[u]};
	head[u]=tot++;
}
vector <int> vec[N];
void init(int n)
{
	isp[1]=phi[1]=1;
	for(int i=2;i<=n;i++)
	{
		if(!isp[i])
		{
			phi[i]=i-1;
			p[++len]=i;
		}
		for(int j=1;j<=len&&p[j]*i<=n;j++)
		{
			isp[i*p[j]]=1;
			if(i%p[j]==0)
			{
				phi[i*p[j]]=phi[i]*p[j];
				break;
			}
			phi[i*p[j]]=phi[i]*phi[p[j]];
		}
	}
}
int col[N],siz[N],root[N];
void dfs(int now,int fa,int rt)
{
	siz[now]=1;
	root[now]=rt;
	for(int i=head[now];i;i=e[i].next)
	{
		int son=e[i].to;
		if(son==fa||!col[son]) continue;
		dfs(son,now,rt);
		siz[now]+=siz[son];
	}
}
void dfs2(int now,int fa)
{
	ll s=1;
	for(int i=head[now];i;i=e[i].next)
	{
		int son=e[i].to;
		if(son==fa||!col[son]) continue;
		dfs2(son,now);
		sum=(sum+1ll*siz[son]*s)%mod;
		s+=siz[son];
	}
	sum=(sum+s*(siz[root[now]]-siz[now]))%mod;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
		scanf("%d",&a[i]),m=max(m,a[i]);
	init(m);
	for(int i=1;i<n;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v);
		add(v,u);
	}
	for(int i=1;i<=n;i++)
	{
		for(int j=1;j<=sqrt(a[i]);j++)
		if(a[i]%j==0)
		{
			vec[j].push_back(i);
			if(j*j!=a[i]) vec[a[i]/j].push_back(i);
		}
	}
	for(int i=1;i<=m;i++)
	{
		sum=0;
		for(int j=0;j<vec[i].size();j++) col[vec[i][j]]=1;
		for(int j=0;j<vec[i].size();j++)
		{
			int u=vec[i][j];
			if(!siz[u])dfs(u,0,u);
		}
		for(int j=0;j<vec[i].size();j++)
		{
			int u=vec[i][j];
			if(u==root[u]) dfs2(u,0);
		}
		ans=(ans+sum*phi[i])%mod;
		for(int j=0;j<vec[i].size();j++)
		{
			int u=vec[i][j];
			col[u]=siz[u]=root[u]=0;
		}
	}
	printf("%lld\n",ans);
	return 0;
}
posted @ 2023-08-19 22:59  g1ove  阅读(50)  评论(0编辑  收藏  举报