CF809E Surprise me 题解

大力推式子+莫比乌斯反演+虚树

Statement

给定一棵 \(n\) 个节点的树,每个点有一个权值 \(a[i]\) ,保证 \(a[i]\) 是一个 \(1\dots n\) 的排列。

\(\frac{1}{n(n−1)}∑^n_{i=1}∑_{j≠i}φ(a_i\times a_j)\times dist(i,j) \),对 \(10^9+7\) 取模。

\(n\le 2\times 10^5\)

Solution

据说里面全是套路,但是我啥都第一次

首先忽略 \(\frac 1{n(n-1)}\) ,探究一下 \(\varphi(a_ia_j)\) 咋办

\(\varphi(a_i a_j)=\dfrac {\varphi(a_i)\varphi(a_j)\gcd(a_i,a_j)}{\varphi(\gcd(a_i,a_j))}\)

看到 \(\gcd\) ,直接提出去算;

\[\sum_{d=1}^n \frac d{\varphi(d)}\sum \sum \varphi(a_i)\varphi(a_j)[\gcd(a_i,a_j)==d]dist(i,j) \]

\(f(d)=\sum \sum \varphi(a_i)\varphi(a_j)[\gcd(a_i,a_j)==d]dist(i,j)\)

看到 \([n==x]\) 的形式,直接莫反,设

\[F(x)=\sum_{x|d} f(d)=\sum \sum \varphi(a_i)\varphi(a_j)[x|\gcd(a_i,a_j)]dist(i,j) \]

知道 \(f(x)=\sum_{x|d}F(d)\mu(\frac dx)\),继续算 \(F(x)\)

发现有点拆不动,但其实已经不用拆了,注意到题目中 \(a_i\) 是一个排列,所以 \(x|a_i\) 的只有 \(n\log n\) 个(级数求和)

我们可以考虑直接枚举 \(x\) ,然后把对应的 \(a_i\) 提出来算贡献,也就是算

\[\sum \sum \varphi(a_i)\varphi(a_j)(dep[i]+dep[j]-dep[lca]) \]

把这个式子拆成三个式子,容易发现形如 \(\sum \sum \varphi(a_i)\varphi(a_j) dep[i]\) 的式子可以直接扫一遍得到 $\sum\varphi(a_i)dep[i] $ 和 \(\sum \varphi(a_i)\) 乘起来即可,设这个乘积叫 \(sum\)

考虑咋算 \(res=\sum \sum \varphi(a_i)\varphi(a_j) dep[lca]\)

不妨提一棵虚树出来,然后在虚树上做一个简单树形 DP 即可

那么 \(F(x)=sum\times 2-res\) ,算 \(F\)\(n\log ^2n\) 的,建虚树还有一个 \(log\)

然后 \(O(n\log n)\) 可以反演得到 \(f(x)\),结束

所以总复杂度 \(O(n\log^2 n)\) ,有一点码量,虚树清零的时候小心

Code

#include<bits/stdc++.h>
using namespace std;
const int N = 2e5+5;
const int mod = 1e9+7;

char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
int read(){
    int s=0,w=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
    while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
    return s*w;
}
void inc(int &a,int b){a=a>=mod-b?a-mod+b:a+b;}
void dec(int &a,int b){a=a>=b?a-b:a+mod-b;}
int ksm(int a,int b){
    int res=1;
    while(b){
        if(b&1)res=1ll*res*a%mod;
        a=1ll*a*a%mod,b>>=1;
    }
    return res;
}

int w[N],rev[N],g[N],f[N];
int n,elen,ans;
bool vis[N];

struct Real_Tree{
    vector<int>Edge[N];
    int siz[N],son[N],top[N],f[N],dep[N],dfn[N];
    int tim;

    void dfs1(int u){
        for(auto v:Edge[u])if(v^f[u])
            dep[v]=dep[f[v]=u]+(siz[v]=1),dfs1(v),siz[u]+=siz[v],
            (siz[v]>siz[son[u]]&&(son[u]=v,1));
    }
    void dfs2(int u,int tp){
        top[u]=tp,dfn[u]=++tim;
        if(son[u])dfs2(son[u],tp);
        for(auto v:Edge[u])if(v^f[u]&&v^son[u])dfs2(v,v);
    }
    int lca(int u,int v){
        while(top[u]^top[v])
            dep[top[u]]<dep[top[v]]?v=f[top[v]]:u=f[top[u]];
        return dep[u]<dep[v]?u:v;
    }

    void build(){
        for(int i=1,u,v;i<n;++i)
            u=read(),v=read(),
            Edge[u].push_back(v),
            Edge[v].push_back(u);
        dfs1(siz[1]=dep[1]=1),dfs2(1,1);
    }
}rt;
struct Math_Fuction{
    int phi[N],mu[N],prime[N];
    bool vis[N];
    int cnt;

    void build(){
        phi[1]=mu[1]=1;
        for(int i=2;i<N;++i){
            if(!vis[i])phi[i]=i-1,mu[i]=-1,prime[++cnt]=i;
            for(int j=1;j<=cnt&&i*prime[j]<N;++j){
                vis[i*prime[j]]=true;
                if(i%prime[j]==0){
                    phi[i*prime[j]]=phi[i]*prime[j];
                    break;
                }
                phi[i*prime[j]]=phi[i]*phi[prime[j]];
                mu[i*prime[j]]=-mu[i];
            }
        }
    }
}mf;
struct Virual_Tree{
    struct Edge{int nex,to;}edge[N<<1];
    int head[N],spc[N],dp[N];
    int elen=1,num,res;
    
    void addedge(int u,int v){
        edge[++elen]=(Edge){head[u],v},head[u]=elen;
        edge[++elen]=(Edge){head[v],u},head[v]=elen;
    }
    void reset(){
        for(int i=1;i<=num;++i)
            head[spc[i]]=vis[spc[i]]=0;
        elen=1,num=0;
    }
    
    void build(){
        sort(spc+1,spc+1+num,[](int x,int y){
            return rt.dfn[x]<rt.dfn[y];});
        for(int i=2;i<=num;++i){
            int l=rt.lca(spc[i],spc[i-1]);
            if(l!=spc[i]&&l!=spc[i-1])spc[++num]=l;
        }
        sort(spc+1,spc+1+num);
        num=unique(spc+1,spc+1+num)-spc-1;
        sort(spc+1,spc+1+num,[](int x,int y){
            return rt.dfn[x]<rt.dfn[y];});
        for(int i=2;i<=num;++i)
            addedge(rt.lca(spc[i],spc[i-1]),spc[i]);
    }
    void dfs(int u,int fath){
        if(vis[u])inc(res,2ll*mf.phi[w[u]]*mf.phi[w[u]]*rt.dep[u]%mod),dp[u]=mf.phi[w[u]];
        for(int e=head[u],v;v=edge[e].to,e;e=edge[e].nex)if(v^fath)
            dfs(v,u),inc(res,4ll*dp[u]*dp[v]%mod*rt.dep[u]%mod),inc(dp[u],dp[v]),dp[v]=0;
    }
    int calc(){
        res=0,dfs(spc[1],spc[1]),dp[spc[1]]=0;
        return res;
    }
}vt;

void calc(int x){
    int sum1=0,sum2=0;
    for(int i=x;i<=n;i+=x)
        vt.spc[++vt.num]=rev[i],vis[rev[i]]=1,
        inc(sum1,1ll*rt.dep[rev[i]]*mf.phi[i]%mod),
        inc(sum2,mf.phi[i]);
    vt.build();
    sum1=1ll*sum1*sum2%mod;
    g[x]=(2ll*sum1-vt.calc()+mod)%mod;
    vt.reset();
}

signed main(){
    n=read();
    for(int i=1;i<=n;++i)
        w[i]=read(),rev[w[i]]=i;
    rt.build(),mf.build();
    for(int i=1;i<=n/2;++i)
        calc(i);
    for(int i=1;i<=n;++i)
        for(int j=i;j<=n;j+=i)
            inc(f[i],(g[j]*mf.mu[j/i]+mod)%mod);
    for(int i=1;i<=n;++i)
        inc(ans,1ll*i*ksm(mf.phi[i],mod-2)%mod*f[i]%mod);
    ans=1ll*ans*ksm(n,mod-2)%mod*ksm(n-1,mod-2)%mod;
    printf("%d\n",ans);
    return 0;
}
posted @ 2022-04-05 17:06  _Famiglistimo  阅读(25)  评论(0编辑  收藏  举报