CF809E Surprise me!
题目传送门
分析:
我们要求
\(\frac{1}{n(n-1)}\sum_{i=1}^{n}\sum_{j=1}^{n}\varphi(a_{i}a_j)dist(i,j)\)
先看一下怎么求\(\varphi(a_{i}a_j)\)
回归欧拉函数本质的式子:
\(\varphi(xy)=xy\prod_{p|xy}(1-\frac{1}{p})\)
\(\varphi(x)\varphi(y)=xy\prod_{p|x}(1-\frac{1}{p})\prod_{p|y}(1-\frac{1}{p})\)
两式相除:
\(\frac{\varphi(x)\varphi(y)}{\varphi(xy)}=\frac{\prod_{p|x}(1-\frac{1}{p})\prod_{p|y}(1-\frac{1}{p})}{\prod_{p|xy}(1-\frac{1}{p})}\)
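由于\(p|xy\)当且仅当\(p|x\)或\(p|y\),分母把\(x\)与\(y\)的每个质因子恰好各算一次;而分子中同时整除\(x\)和\(y\)(即整除\(gcd(x,y)\))的质因子被算了两次,约分后只剩下这一部分: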
\(~~~~\frac{\varphi(x)\varphi(y)}{\varphi(xy)}\)
\(=\prod_{p|gcd(x,y)}(1-\frac{1}{p})\)
\(=\frac{\varphi(gcd(x,y))}{gcd(x,y)}\)
所以
\(\varphi(xy)=\frac{\varphi(x)\varphi(y)gcd(x,y)}{\varphi(gcd(x,y))}\)
于是开始推式子:
\(~~~~\sum_{i=1}^{n}\sum_{j=1}^{n}\varphi(a_{i}a_j)dist(i,j)\)
\(=\sum_{i=1}^{n}\sum_{j=1}^{n}\frac{\varphi(a_i)\varphi(a_j)gcd(a_i,a_j)}{\varphi(gcd(a_i,a_j))}dist(i,j)\)
枚举\(gcd(a_i,a_j)=d\)
\(=\sum_{d=1}^{n}\frac{d}{\varphi(d)}\sum_{i=1}^{n}\sum_{j=1}^{n}[gcd(a_i,a_j)=d]\varphi(a_i)\varphi(a_j)dist(i,j)\)
令\(f(d)=\sum_{i=1}^{n}\sum_{j=1}^{n}[gcd(a_i,a_j)=d]\varphi(a_i)\varphi(a_j)dist(i,j)\)
不好求
我们再令\(F(d)=\sum_{i=1}^{n}\sum_{j=1}^{n}[d|gcd(a_i,a_j)]\varphi(a_i)\varphi(a_j)dist(i,j)\)
可以看出\(F(i)=\sum_{i|d}f(d)\)(\(d\)取遍\(i\)的倍数)
于是乎由莫比乌斯反演,\(f(i)=\sum_{i|d}\mu(\frac{d}{i})F(d)\)
我们知道\(F(d)\)后,便可以\(O(n\log n)\)的时间求出\(f(d)\)
考虑每一个\(d\),由于\(d|gcd(a_i,a_j)\),所以只有满足\(d|a_i\)的点会加入;因为\(a\)是\(1\sim n\)的排列,这样的点恰好有\(\lfloor\frac{n}{d}\rfloor\)个
总点数是\(\sum_{d=1}^{n}\lfloor\frac{n}{d}\rfloor=O(n\log n)\)级别
对于每个\(d\),构建虚树,设其中有\(m\)个点,设\(v_i=\varphi(a_i)\)(为求\(LCA\)而补进虚树、本身不满足\(d|a_i\)的点令\(v_i=0\))
\(F(d)=\sum_{i=1}^{m}\sum_{j=1}^{m}v_i v_j dist(i,j)\)
\(=\sum_{i=1}^{m}\sum_{j=1}^{m}v_i v_j (dpt(i)+dpt(j)-2dpt(LCA(i,j)))\)
展开
\(=\sum_{i=1}^{m}\sum_{j=1}^{m}v_i v_j dpt(i)+\sum_{i=1}^{m}\sum_{j=1}^{m}v_i v_j dpt(j)-2\sum_{i=1}^{m}\sum_{j=1}^{m}v_i v_j dpt(LCA(i,j))\)
前面俩其实等价
\(=2\sum_{i=1}^{m}\sum_{j=1}^{m}v_i v_j dpt(i)-2\sum_{i=1}^{m}\sum_{j=1}^{m}v_i v_j dpt(LCA(i,j))\)
前面的直接预处理可以算,后面的树形\(dp\)计算每个点为\(LCA\)时整棵子树的总和
于是这道题就解决了,复杂度\(O(n\log^{2}n)\)
一道很好(丧病)的数论大礼包+虚树+树形dp的题
写好了调半天
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<iostream>
#include<map>
#include<bitset>
#include<string>
// Global capacity of every per-vertex / per-edge / sieve array. It must cover
// 2n-1 virtual-tree vertices in p[] and 2(n-1) directed edges; presumably
// chosen for n <= 2e5 — confirm against the problem limits.
const int maxn=400005;
// Large sentinel value (not referenced in the visible code).
const int INF=0x3f3f3f3f;
// Prime modulus for all arithmetic; inverses are taken via Fermat (x^(MOD-2)).
const int MOD=1000000007;
using namespace std;
// Reads one integer from stdin. Every character before the first digit is
// skipped; any '-' among those skipped characters makes the result negative.
// The first non-digit after the number is consumed.
inline long long getint()
{
    long long sign=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    long long value=0;
    for(;ch>='0'&&ch<='9';ch=getchar())value=value*10+(ch-'0');
    return value*sign;
}
// ---- Global state. Most arrays are 1-based; index/id 0 acts as a null
// sentinel (e.g. fir[u]==0 ends an adjacency list, son[u]==0 means no child).
int n;                    // number of vertices
int a[maxn];              // a[u]: value written on vertex u
int id[maxn];             // id[v]: vertex whose value is v (inverse of a[])
int pos[maxn];            // pos[u]: DFS preorder index of u (set in dfs1)
int cur;                  // preorder counter used by dfs1
int F[maxn];              // F[d]: pair sum over vertices with d | value (set in solve)
int pri[maxn];            // primes found by the sieve, pri[1..pcnt]
int phi[maxn];            // Euler's totient
int mu[maxn];             // Moebius function stored mod MOD (-1 kept as MOD-1)
int np[maxn];             // np[x] != 0 iff the sieve marked x composite
int pcnt;                 // number of primes in pri[]
int fir[maxn];            // fir[u]: first edge id of u's adjacency list
int nxt[maxn];            // nxt[e]: next edge id in the same list
int to[maxn];             // to[e]: head vertex of edge e
int cnt;                  // number of directed edges inserted so far
int sz[maxn];             // subtree size
int son[maxn];            // heavy child
int dpt[maxn];            // depth (root 1 has depth 0)
int fa[maxn];             // parent (0 for the root)
int tp[maxn];             // top vertex of u's heavy chain
std::vector<int> G[maxn]; // children of each vertex in the current virtual tree
int p[maxn];              // virtual-tree vertex buffer used by solve
int stk[maxn];            // stack used while building the virtual tree
int top;                  // current stack height
int sum[maxn];            // dp: sum of phi over marked vertices in the subtree
int f[maxn];              // dp: pair sum with LCA = u; reused in main for f(d)
int vis[maxn];            // vis[u] = 1 while u is a marked vertex inside solve
int ans;                  // total before dividing by n(n-1)
// Binary (square-and-multiply) exponentiation: returns num^k mod MOD.
// Used with k = MOD-2 to obtain modular inverses.
int ksm(int num,int k)
{
    long long base=num;
    long long result=1;
    while(k)
    {
        if(k&1)result=result*base%MOD;
        base=base*base%MOD;
        k>>=1;
    }
    return static_cast<int>(result);
}
// Adds the directed edge u -> v at the head of u's adjacency list.
void newnode(int u,int v)
{
    ++cnt;
    to[cnt]=v;
    nxt[cnt]=fir[u];
    fir[u]=cnt;
}
// Single conditional subtraction of MOD; brings x into [0, MOD) provided
// 0 <= x < 2*MOD (callers are expected to guarantee that range).
int upd(int x)
{
    if(x>=MOD)return x-MOD;
    return x;
}
// Sort predicate: orders vertices by their DFS preorder index pos[].
bool cmp(int x,int y)
{
    const int px=pos[x];
    const int py=pos[y];
    return px<py;
}
// Linear (Euler) sieve over [1, maxn). Fills pri[1..pcnt] with the primes,
// np[] with the composite flag, phi[] with Euler's totient and mu[] with the
// Moebius function stored as a canonical residue mod MOD (-1 -> MOD-1, 0 -> 0).
// Each composite is visited once, via its smallest prime factor (the inner
// loop stops at the first prime dividing i).
void init()
{
    mu[1]=phi[1]=1;
    for(int i=2;i<maxn;i++)
    {
        if(!np[i])pri[++pcnt]=i,phi[i]=i-1,mu[i]=MOD-1;
        for(int j=1;j<=pcnt&&i*pri[j]<maxn;j++)
        {
            const int q=i*pri[j];
            np[q]=1;
            if(i%pri[j])
            {
                // pri[j] is a new prime factor of q: multiplicativity applies.
                phi[q]=phi[i]*(pri[j]-1);
                // Negate mod MOD, but keep 0 as 0: the old MOD-mu[i] stored the
                // non-canonical residue MOD whenever mu[i]==0.
                mu[q]=mu[i]?MOD-mu[i]:0;
            }
            else
            {
                // pri[j] already divides i, so q has a squared prime factor and
                // mu[q] keeps its zero-initialised value.
                phi[q]=phi[i]*pri[j];
                break;
            }
        }
    }
}
// First heavy-light pass from u: assigns the preorder index pos[], subtree
// size sz[], parent fa[] and depth dpt[] of every descendant, and picks son[u]
// as the first child (in adjacency order) with strictly the largest subtree.
void dfs1(int u)
{
    pos[u]=++cur;
    sz[u]=1;
    for(int e=fir[u];e;e=nxt[e])
    {
        const int v=to[e];
        if(v==fa[u])continue;
        fa[v]=u;
        dpt[v]=dpt[u]+1;
        dfs1(v);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]])son[u]=v;
    }
}
// Second heavy-light pass: tp[u] = ac, the top vertex of u's heavy chain.
// The heavy child continues the current chain; each light child starts its own.
void dfs2(int u,int ac)
{
    tp[u]=ac;
    if(son[u]!=0)dfs2(son[u],ac);
    for(int e=fir[u];e;e=nxt[e])
    {
        const int v=to[e];
        if(v==fa[u]||v==son[u])continue;
        dfs2(v,v);
    }
}
// Lowest common ancestor using the heavy-light decomposition: keep lifting the
// endpoint whose chain top is deeper until both sit on the same chain, then
// the shallower of the two is the answer.
int LCA(int u,int v)
{
    while(tp[u]!=tp[v])
    {
        if(dpt[tp[u]]<dpt[tp[v]])std::swap(u,v);
        u=fa[tp[u]];
    }
    if(dpt[u]<dpt[v])return u;
    return v;
}
// DP over the virtual tree rooted at u (children in G[]). Afterwards:
//   sum[u] = sum of phi[a[w]] over marked (vis) vertices w in u's subtree;
//   f[u]  += sum of phi[a[x]]*phi[a[y]] over ordered pairs x != y of marked
//            vertices in that subtree whose LCA is u (all mod MOD).
// f[u] is accumulated onto, so it must be zero before the call.
void dp(int u)
{
    sum[u]=vis[u]*phi[a[u]];
    for(size_t k=0;k<G[u].size();++k)
    {
        const int v=G[u][k];
        dp(v);
        // Pair everything merged into u so far with everything under v,
        // counting both orders.
        f[u]=upd(f[u]+2ll*sum[u]*sum[v]%MOD);
        sum[u]=upd(sum[u]+sum[v]);
    }
}
// Computes F[x] = sum over ordered pairs (i,j) of vertices with x | a_i and
// x | a_j of phi(a_i)*phi(a_j)*dist(i,j), mod MOD, on a virtual tree built
// over exactly those vertices.
// NOTE(review): assumes a[] is a permutation of 1..n so that id[v] names a
// vertex for every v <= n (main fills id[] as the inverse of a[]).
// Uses and then resets the shared globals p, stk/top, G, vis, sum and f.
inline void solve(int x)
{
int tot=0;
// Collect and mark every vertex whose value is a multiple of x.
for(int i=x;i<=n;i+=x)p[++tot]=id[i],vis[id[i]]=1;
// Sort by DFS preorder, then append the LCA of each adjacent pair so the
// vertex set becomes closed under LCA.
sort(p+1,p+tot+1,cmp);
for(int i=tot-1;i;i--)p[++tot]=LCA(p[i],p[i+1]);
// Re-sort and remove duplicates; p[1] (earliest in preorder) is then the LCA
// of the whole set and serves as the virtual-tree root.
sort(p+1,p+tot+1,cmp);tot=unique(p+1,p+tot+1)-p-1;
stk[++top]=p[1];
for(int i=2;i<=tot;i++)
{
// Pop stack vertices whose preorder range [pos, pos+sz) ends before p[i],
// i.e. whose subtree does not contain p[i]; the remaining top is p[i]'s
// parent in the virtual tree.
while(top&&pos[stk[top]]+sz[stk[top]]<=pos[p[i]])top--;
G[stk[top]].push_back(p[i]),stk[++top]=p[i];
}
// Empty the stack for the next call.
top=0;
int tmp1=0,tmp2=0;
// tmp2 = S = sum of phi(a_i) over the marked vertices.
for(int i=1;i<=tot;i++)if(vis[p[i]])tmp2=upd(tmp2+phi[a[p[i]]]);
// tmp1 = sum_i phi(a_i)*dpt(i)*S = sum_{i,j} v_i v_j dpt(i).
for(int i=1;i<=tot;i++)if(vis[p[i]])tmp1=upd(tmp1+1ll*phi[a[p[i]]]*dpt[p[i]]%MOD*tmp2%MOD);
// Run the tree DP; tmp2 is reused to accumulate sum_{i,j} v_i v_j dpt(LCA(i,j)).
dp(p[1]);tmp2=0;
// Pairs i != j: f[u] holds the v_i v_j sum of the pairs whose LCA is u.
for(int i=1;i<=tot;i++)tmp2=upd(tmp2+1ll*dpt[p[i]]*f[p[i]]%MOD);
// Pairs i == j: LCA(i,i) = i contributes v_i^2 * dpt(i).
for(int i=1;i<=tot;i++)if(vis[p[i]])tmp2=upd(tmp2+1ll*dpt[p[i]]*phi[a[p[i]]]%MOD*phi[a[p[i]]]%MOD);
// dist(i,j) = dpt(i)+dpt(j)-2*dpt(LCA), hence F = 2*tmp1 - 2*tmp2 (mod MOD).
F[x]=upd(upd(tmp1*2)-upd(tmp2*2)+MOD);
// Reset every per-vertex array this call touched.
for(int i=1;i<=tot;i++)G[p[i]].clear(),f[p[i]]=sum[p[i]]=vis[p[i]]=0;
}
int main()
{
init();
n=getint();
for(int i=1;i<=n;i++)id[a[i]=getint()]=i;
for(int i=1;i<n;i++)
{
int u=getint(),v=getint();
newnode(u,v),newnode(v,u);
}
dfs1(1),dfs2(1,1);
for(int i=1;i<=n/2;i++)solve(i);
for(int i=1;i<=n;i++)for(int j=i;j<=n;j+=i)f[i]=upd(f[i]+1ll*F[j]*mu[j/i]%MOD);
for(int i=1;i<=n;i++)ans=upd(ans+1ll*ksm(phi[i],MOD-2)*i%MOD*f[i]%MOD);
printf("%lld\n",1ll*ans*ksm(1ll*n*(n-1)%MOD,MOD-2)%MOD);
}