[Codeforces 809E] Surprise me!

题目大意

这题神奇的洛谷有翻译:https://www.luogu.org/problemnew/show/CF809E

\(CF\)题面:https://codeforces.com/problemset/problem/809/E

Solution

奇怪的题...

对于\(\varphi\)有一个长这样的性质:

\[\varphi(ab)=\frac{\varphi(a)\varphi(b)\gcd(a,b)}{\varphi(\gcd(a,b))} \]

证明可以针对每个质因子考虑,然后结合\(\varphi\)的性质就好了。

那么我们就可以愉快的化式子了:

\[\sum_{i=1}^{n}\sum_{j=1}^{n}\frac{\varphi(a_i)\varphi(a_j)\gcd(a_i,a_j)}{\varphi(\gcd(a_i,a_j))}dis(i,j) \]

枚举\(\gcd\)结果:

\[\sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{i=1}^{n}\sum_{j=1}^{n}\varphi(a_i)\varphi(a_j)[\gcd(a_i,a_j)=d]dis(i,j) \]

莫比乌斯反演:

\[\sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{i=1}^{n/d}\sum_{j=1}^{n/d}\sum_{t|i,t|j}\mu(d)\varphi(id)\varphi(jd)dis(b_{id},b_{jd}) \]

这一步比较神奇,注意到\(a_i\)是个排列,那么我们可以求出每个值的位置\(b_i\)满足\(b_{a_i}=i\)

然后枚举\(a_i\)的结果,各种乱推就变成这样了。

\(\sum_t\)提前:

\[\sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{t=1}^{n/d}\mu(t)\sum_{i=1}^{n/dt}\sum_{j=1}^{n/dt}\varphi(idt)\varphi(jdt)dis(b_{idt},b_{jdt}) \]

\(T=dt\)

\[\sum_{T=1}^n\sum_{d|T}\frac{d}{\varphi(d)}\mu(\frac{T}{d})\sum_{i=1}^{n/T}\sum_{j=1}^{n/T}\varphi(iT)\varphi(jT)dis(b_{iT},b_{jT}) \]

到这一步化式子就差不多了,写的好看一点,设:

\[f(T)=\sum_{d|T}\frac{d}{\varphi(d)}\mu(\frac{T}{d}) \]

式子变成:

\[\sum_{T=1}^nf(T)\sum_{i=1}^{n/T}\sum_{j=1}^{n/T}\varphi(iT)\varphi(jT)dis(b_{iT},b_{jT}) \]

注意到对于每个\(T\)只涉及到了\(n/T\)个点,那么总点数就是:

\[O(\sum_{i=1}^{n}\frac{n}{i})= O(n\log n) \]

那么直接建虚树暴力搞就可以了。

总复杂度\(O(n\log ^2 n)\)

代码比较精神污染

#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long 

const int maxn = 2e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 1e9+7;

int a[maxn],n,b[maxn],phi[maxn],isp[maxn],pri[maxn],cnt,f[maxn],iphi[maxn],dfn[maxn],sz[maxn],dep[maxn],mu[maxn];

int qpow(int A,int x) {
	int res=1;
	for(;x;x>>=1,A=1ll*A*A%mod) if(x&1) res=1ll*res*A%mod;
	return res;
}

void sieve() {
	phi[1]=mu[1]=1;
	for(int i=2;i<maxn;i++) {
		if(!isp[i]) pri[++cnt]=i,phi[i]=i-1,mu[i]=-1;
		for(int j=1;j<=cnt&&i*pri[j]<maxn;j++) {
			isp[i*pri[j]]=1;
			if(i%pri[j]==0) {phi[i*pri[j]]=phi[i]*pri[j];break;}
			phi[i*pri[j]]=phi[i]*phi[pri[j]];
			mu[i*pri[j]]=-mu[i];
		}
	}
	for(int i=1;i<maxn;i++) iphi[i]=qpow(phi[i],mod-2);
	for(int d=1;d<maxn;d++)
		for(int T=d;T<maxn;T+=d)
			f[T]=(f[T]+1ll*d*iphi[d]%mod*mu[T/d]%mod)%mod;
}

struct Input_Tree {
	int head[maxn],tot,top[maxn],hs[maxn],fa[maxn],dfn_cnt,res;
	struct edge{int to,nxt;}e[maxn<<1];

	void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
	void ins(int u,int v) {add(u,v),add(v,u);}
	
	void dfs1(int x,int Fa) {
		fa[x]=Fa,dep[x]=dep[Fa]+1,sz[x]=1;
		for(int i=head[x];i;i=e[i].nxt)
			if(e[i].to!=Fa) {
				dfs1(e[i].to,x);sz[x]+=sz[e[i].to];
				if(sz[hs[x]]<sz[e[i].to]) hs[x]=e[i].to;
			}
	}

	void dfs2(int x) {
		dfn[x]=++dfn_cnt;
		if(hs[fa[x]]==x) top[x]=top[fa[x]];
		else top[x]=x;
		for(int i=head[x];i;i=e[i].nxt)
			if(e[i].to!=fa[x]) dfs2(e[i].to);
	}

	int lca(int x,int y) {
		while(top[x]!=top[y]) {
			if(dep[top[x]]<dep[top[y]]) swap(x,y);
			x=fa[top[x]];
		}return dep[x]<dep[y]?x:y;
	}
}T;

int cmp(int x,int y) {return dfn[x]<dfn[y];}

struct Virtual_Tree {
	int head[maxn],tot,sta[maxn],use[maxn],top,used,r[maxn],k,sum[maxn],sum2[maxn],res,vis[maxn],g[maxn];
	struct edge{int to,nxt;}e[maxn<<1];

	void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
	void ins(int u,int v) {add(u,v),add(v,u);}

	void build() {
		sort(r+1,r+k+1,cmp);
		sta[++top]=1;
		for(int i=1;i<=k;i++) {
			if(r[i]==1) continue;
			int t=T.lca(sta[top],r[i]),pre=-1;
			while(dfn[sta[top]]>dfn[t]&&dfn[sta[top]]<dfn[t]+sz[t]) {
				if(pre!=-1) ins(sta[top],pre);
				pre=sta[top];use[++used]=sta[top];top--;
			}
			if(pre!=-1) ins(t,pre);
			if(sta[top]!=t) sta[++top]=t;
			sta[++top]=r[i];
		}
		int pre=-1;
		while(top) {
			if(pre!=-1) ins(sta[top],pre);
			pre=sta[top];use[++used]=sta[top],top--;
		}
	}

	void clear() {
		for(int v,i=1;i<=used;i++) head[v=use[i]]=0,sum[v]=sum2[v]=vis[v]=g[v]=0;
		top=tot=used=k=0;
	}

	void dfs1(int x,int fa) {
		sum[x]=phi[a[x]]*vis[x],sum2[x]=1ll*(dep[x]-1)*phi[a[x]]*vis[x]%mod;
		for(int i=head[x];i;i=e[i].nxt)
			if(e[i].to!=fa) {
				dfs1(e[i].to,x),sum[x]=(sum[x]+sum[e[i].to])%mod;
				sum2[x]=(sum2[x]+sum2[e[i].to])%mod;
			}
	}

	void dfs2(int x,int fa) {
		if(x!=1) {
			g[x]=(1ll*g[fa]-1ll*sum[x]*(dep[x]-dep[fa])%mod+1ll*(sum[1]-sum[x])*(dep[x]-dep[fa])%mod)%mod;
			if(vis[x]) res=(res+1ll*phi[a[x]]*g[x]%mod)%mod;
		}
		for(int i=head[x];i;i=e[i].nxt)	if(e[i].to!=fa) dfs2(e[i].to,x);
	}
	
	int solve(int t) {
		for(int i=t;i<=n;i+=t) r[++k]=b[i],vis[r[k]]=1;
		build();res=0;dfs1(1,0),g[1]=sum2[1],dfs2(1,0);
		if(vis[1]) res=(res+1ll*g[1]*phi[a[1]]%mod)%mod;
		clear();return res;
	}
}VT;

int main() {
	sieve();
	read(n);for(int i=1;i<=n;i++) read(a[i]);
	for(int i=1;i<=n;i++) b[a[i]]=i;
	for(int i=1,x,y;i<n;i++) read(x),read(y),T.ins(x,y);
	T.dfs1(1,0),T.dfs2(1);
	int ans=0;
	for(int i=1;i<=n;i++) ans=(ans+1ll*f[i]*VT.solve(i)%mod)%mod;
	write((1ll*ans*qpow(1ll*n*(n-1)%mod,mod-2)%mod+mod)%mod);
	return 0;
}
posted @ 2019-03-21 11:06  Hyscere  阅读(168)  评论(0编辑  收藏  举报