YbtOJ#723-欧拉之树【莫比乌斯反演,虚树】

正题

题目链接:http://www.ybtoj.com.cn/contest/121/problem/2


题目大意

给出\(n\)个点的一棵树,每个点有一个权值\(a_i\),求

\[\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^ndis(i,j)\times \varphi(a_i\times a_j) \]

\(2\leq n\leq 2\times 10^5\),\(a\)恰好是一个\(1\sim n\)的排列。


解题思路

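一个十分显然的结论就是\(\varphi(x\times y)=\varphi(x)\times \varphi(y)\times \frac{gcd(x,y)}{\varphi(gcd(x,y))}\)。(\(x\)和\(y\)共有的质因子\(p\)在\(\varphi(x)\varphi(y)\)中贡献了两次\(\frac{p-1}{p}\),而在\(\varphi(xy)\)中只应保留一次,乘上\(\frac{gcd(x,y)}{\varphi(gcd(x,y))}\)正好消去多出来的那一次)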

然后顺便把点编号换一下使得\(a_i=i\)再枚举约数就是

\[\sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{i=1}^n\sum_{j=1}^{n}dis(i,j) \varphi(i)\varphi(j)\times [gcd(i,j)=d] \]

然后就可以莫反了,定义

\[g_d=\sum_{d|i}^n\sum_{d|j}^ndis(i,j)\varphi(i)\varphi(j) \]

\[g_d=\sum_{d|i}^n\sum_{d|j}^n(dep_i+dep_j-2dep_{lca(i,j)})\varphi(i)\varphi(j) \]

\[g_d=2\sum_{d|i}^ndep_{i}\varphi(i)\sum_{d|j}^n\varphi(j)-2\sum_{k=1}^ndep_{k}\sum_{d|i}^n\sum_{d|j}^n[lca(i,j)=k]\varphi(i)\varphi(j) \]

把所有编号是\(d\)的倍数的点加入虚树,然后用树形\(dp\)计算后面那一项,前面那一项可以直接算。

然后答案就是

\[\sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{d|i}\mu\left(\frac{i}{d}\right)g_i \]

时间复杂度\(O(n\log^2 n)\),有点卡常。


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cctype>
#define ll long long
#pragma GCC optimize(2)
%:pragma GCC optimize(3)
%:pragma GCC optimize("Ofast")
%:pragma GCC optimize("inline")
using namespace std;
// N: capacity of every global array; T: number of levels in the sparse table f[][];
// P: modulus for all arithmetic (inverses are taken as x^(P-2), so P is treated as prime).
const int N=4e5+10,T=20,P=1e9+7;
// Reads the next decimal integer from stdin.
// Skips every non-digit character before the number; each '-' among them flips the sign.
// Returns 0 if end of input is reached before any digit is seen.
int read() {
	int x=0,f=1;
	// getchar() returns int: keeping it as int lets EOF be told apart from real
	// characters and keeps the isdigit() argument inside its defined domain.
	int c=getchar();
	while(c!=EOF&&!isdigit(c)) {if(c=='-')f=-f;c=getchar();}
	while(c!=EOF&&isdigit(c)) x=x*10+(c-'0'),c=getchar();
	return x*f;
}
// One directed edge of a linked adjacency list.
struct node{
	int to;    // node the edge points at
	int next;  // index in a[] of the next edge leaving the same source; 0 ends the list
};
// Edge pool; slots are handed out starting at index 1.
node a[N<<1];
// n: number of nodes; m: number of key nodes for the divisor being processed;
// tot: number of edges allocated in a[]; top: height of the stack s[];
// dfc: preorder counter; ls[x]: first edge of x's adjacency list (0 = none);
// rfl[i]: value read for input node i, used as that node's label from then on.
int n,m,tot,top,dfc,ls[N],rfl[N];
// mu[] / phi[]: Mobius and Euler totient values from the sieve; pri[1..prn]: primes found;
// s[]: stack used while building a virtual tree; p[1..m]: key nodes of the current divisor.
int mu[N],phi[N],pri[N],prn,s[N],p[N];
// stn: length of the Euler tour; lg[i]: floor(log2(i)); wz[x]: first tour position of x;
// rfn[x]: preorder number of x; dep[x]: depth of x (the root has depth 1);
// f[i][j]: shallowest node among the 2^j tour entries starting at position i.
int stn,lg[N],wz[N],rfn[N],dep[N],f[N][T];
// S[x] / dp[x]: per-node accumulators filled by calc(); g[k]: value computed for divisor k;
// ans: final accumulated answer; v[i]: i is composite; mark[x]: x is a key node right now.
ll S[N],dp[N],g[N],ans;bool v[N],mark[N];
// Binary exponentiation: returns x raised to the power b, reduced modulo P.
ll power(ll x,ll b){
	ll res=1;
	for(;b;b>>=1){
		if(b&1)res=res*x%P;
		x=x*x%P;
	}
	return res;
}
// Linear sieve over 2..n. Fills phi[] (Euler's totient) and mu[] (Mobius function),
// marks composites in v[] and appends each prime to pri[1..prn].
void prime(){
	phi[1]=1;
	mu[1]=1;
	for(int i=2;i<=n;i++){
		if(!v[i]){
			pri[++prn]=i;
			phi[i]=i-1;
			mu[i]=-1;
		}
		for(int j=1;j<=prn;j++){
			const int q=pri[j];
			if(i*q>n)break;
			const int c=i*q;
			v[c]=1;
			if(i%q==0){
				// q already divides i: mu[c] is left at its initial value.
				phi[c]=phi[i]*q;
				break;
			}
			phi[c]=phi[i]*(q-1);
			mu[c]=-mu[i];
		}
	}
}
// Allocates the next edge slot and pushes a directed edge x -> y onto the front of
// x's adjacency list.
void addl(int x,int y){
	const int id=++tot;
	a[id].to=y;
	a[id].next=ls[x];
	ls[x]=id;
}
// Sort predicate: x goes before y when x has the smaller preorder number.
bool cmp(int x,int y){
	return rfn[y]>rfn[x];
}
// Depth-first walk from x (parent fa). Records the depth and preorder number of x and
// appends x to the Euler tour f[..][0] on entry and again after each child returns;
// wz[x] remembers the entry position.
void dfs(int x,int fa){
	rfn[x]=++dfc;
	dep[x]=dep[fa]+1;
	wz[x]=++stn;
	f[stn][0]=x;
	for(int e=ls[x];e;e=a[e].next){
		const int child=a[e].to;
		if(child!=fa){
			dfs(child,x);
			f[++stn][0]=x;
		}
	}
}
int LCA(int l,int r){
	l=wz[l];r=wz[r];
	if(l>r)swap(l,r);
	int z=lg[r-l+1],x=f[l][z],y=f[r-(1<<z)+1][z];
	return dep[x]<dep[y]?x:y;
}
// Adds key node x to the virtual tree under construction. s[1..top] is the chain of
// nodes whose edge to their parent has not been emitted yet; main() feeds the key nodes
// in increasing rfn order.
void Ins(int x){
	// Empty stack: x starts the chain.
	if(!top){s[++top]=x;return;}
	int lca=LCA(x,s[top]);
	// While the second-from-top node is still deeper than lca, the top node is finished:
	// link it under that node and pop it.
	while(top>1&&dep[s[top-1]]>dep[lca])
		addl(s[top-1],s[top]),top--;
	// A remaining top that is deeper than lca hangs directly under lca.
	if(dep[s[top]]>dep[lca])addl(lca,s[top]),top--;
	// Put lca on the stack unless it is already the top, then push x above it.
	if((!top)||s[top]!=lca)s[++top]=lca;
	s[++top]=x;return;
}
// Post-order pass over the virtual tree rooted at x.
// S[x] becomes the sum of phi over the marked nodes in x's subtree; dp[x] becomes the sum
// of phi(i)*phi(j) over the ordered pairs of marked nodes that first meet at x.
// dp[x]*dep[x] is subtracted (mod P) from ans. The adjacency head and mark of x are
// cleared on the way out so the arrays are clean for the next tree.
void calc(int x,ll &ans){
	if(mark[x]){
		S[x]=phi[x];
		dp[x]=1ll*phi[x]*phi[x]%P;
	}else{
		S[x]=0;
		dp[x]=0;
	}
	for(int e=ls[x];e;e=a[e].next){
		const int child=a[e].to;
		calc(child,ans);
		// Pairs with one endpoint already merged into x and the other inside child, both orders.
		dp[x]=(dp[x]+S[x]*S[child]*2ll%P)%P;
		S[x]=(S[x]+S[child])%P;
	}
	ans=(ans+P-1ll*dp[x]*dep[x]%P)%P;
	mark[x]=0;
	ls[x]=0;
}
signed main()
{
	freopen("sm.in","r",stdin);
	freopen("sm.out","w",stdout);
	n=read();prime();
	for(int i=1;i<=n;i++){
		int x=read();
		rfl[i]=x;
	}
	for(int i=1;i<n;i++){
		int x=read(),y=read();
		x=rfl[x];y=rfl[y];
		addl(x,y);addl(y,x);
	}
	dfs(1,1);
	for(int j=1;(1<<j)<=stn;j++)
		for(int i=1;i+(1<<j)-1<=stn;i++){
			int x=f[i][j-1],y=f[i+(1<<j-1)][j-1];
			f[i][j]=(dep[x]<dep[y])?x:y;
		}
	for(int i=2;i<=stn;i++)lg[i]=lg[i>>1]+1;	
	memset(ls,0,sizeof(ls));
	for(int k=1;k<=n;k++){
		m=top=tot=0;ll sum=0;
		for(int i=k;i<=n;i+=k)
			p[++m]=i,sum+=phi[i];
		sort(p+1,p+1+m,cmp);sum%=P;
		if(p[1]!=1)s[++top]=1;
		for(int i=1;i<=m;i++){
			Ins(p[i]);mark[p[i]]=1;
			(g[k]+=1ll*phi[p[i]]*dep[p[i]]%P*sum%P)%=P;
		}
		while(top>1)addl(s[top-1],s[top]),top--;
		calc(1,g[k]);g[k]=g[k]*2ll%P;
	}
	for(int i=1;i<=n;i++){
		ll tmp=0;
		for(int j=i;j<=n;j+=i)
			(tmp+=mu[j/i]*g[j]%P)%=P;
		(ans+=tmp*i%P*power(phi[i],P-2)%P)%=P;
	}
	printf("%d\n",(ans+P)%P*power(1ll*n*(n-1)%P,P-2)%P);
	return 0;
}
posted @ 2021-02-22 19:52  QuantAsk  阅读(51)  评论(0)  编辑  收藏  举报