2019.4.24 一题(CF 809E)——推式子+虚树

题目:http://codeforces.com/contest/809/problem/E

  原来以为可以每个质因子分开给答案贡献。

  大概就是把有这个质因子的数都拿出来建虚树,这样虚树的总点数是 nlogn 的。

  定义 b[ i ] 表示 i 点原来的权值分解出的 pt ,其中 p 是目前在做的质因子。那么要求虚树里的 \( \sum\limits_{u} \sum\limits_{v} dis(u,v)*b[u]*b[v] \) 。最后乘 \( \frac{p-1}{p} \) 即可。

  把 dis( u , v ) 拆成 dep[ u ] + dep[ v ] - 2*dep[ lca ] ,就是每个点贡献 \( dep[cr]*b[cr]*\sum\limits_{i!=cr}b[i] \) ,\( -2*dep[cr]*( (\sum\limits_{i \in tree_cr}b[i])^2 - \sum\limits_{i \in tree_cr}b[i]^2 ) \)

  还要考虑当前质因子的 “虚树上的点与虚树外的点” 的贡献,就是 \( b[cr]*\sum\limits_{i}dis(i,cr) \) 。要求 “该点与虚树外的点的距离和” ,可以用 “该点与所有点的距离和” - “该点与虚树上点的距离和” , 换根 DP 一番即可。

  然后把各种质因子的贡献加起来。

  试了一下自己造的样例:

4

1 2 3 4

1 2

1 3

3 4

  发现可以。也没试试题面的样例,就开始写。写完发现是不对的。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ll long long
#define pb push_back
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=2e5+5,K=17,mod=1e9+7;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,a[N],hd[N],xnt,to[N<<1],nxt[N<<1],ans;
int tim,dfn[N],dep[N],pre[N][K+5],siz[N],sm[N];
int pri[N],mnd[N],cnt,dy[N]; bool vis[N];
struct Node{
  int v,w;
  Node(int v=0,int w=0):v(v),w(w) {}
  bool operator< (const Node &b)const
  {return dfn[(*this).v]<dfn[b.v];}
};
vector<Node> vt[N];
namespace VT{
  Node q[N]; int tot,sta[N],top,ret,p;
  int a[N],hd[N],xnt,to[N<<1],nxt[N<<1];
  int siz[N],s2[N],alsm,vl[N],v2[N];
  void add(int x,int y)
  {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
  int get_lca(int x,int y)
  {
    if(dep[x]<dep[y])swap(x,y);
    for(int t=K;t>=0;t--)
      if(dep[pre[x][t]]>=dep[y])x=pre[x][t];
    if(x==y)return x;
    for(int t=K;t>=0;t--)
      if(pre[x][t]!=pre[y][t])
    x=pre[x][t], y=pre[y][t];
    return pre[x][0];
  }
  void build()
  {
    xnt=alsm=0; sort(q+1,q+tot+1); int st,u=q[1].v;
    if(u!=1){sta[top=1]=1; a[1]=0; hd[1]=0; st=1;}//hd[]!!
    else {sta[top=1]=u; a[u]=q[1].w; hd[u]=0; st=2;}
    for(int i=st;i<=tot;i++)
      {
    u=q[i].v; int lca=get_lca(u,sta[top]);
    a[u]=q[i].w; alsm=upt(alsm+a[u]);
    while(dfn[sta[top]]>dfn[lca])
      {
        if(dfn[sta[top-1]]>=lca)add(sta[top-1],sta[top]);
        else add(lca,sta[top]);
        top--;
      }
    if(sta[top]!=lca)
      { sta[++top]=lca; a[lca]=0; hd[lca]=0;}
    sta[++top]=u; hd[u]=0;
      }
    for(int i=top-1;i;i--)add(sta[i],sta[i+1]);
  }
  void dfs(int cr)
  {
    if(a[cr]){siz[cr]=1;vl[cr]=a[cr];v2[cr]=(ll)a[cr]*a[cr]%mod;}
    else siz[cr]=vl[cr]=v2[cr]=0;
    s2[cr]=0;
    for(int i=hd[cr],v;i;i=nxt[i])
      {
    dfs(v=to[i]); siz[cr]+=siz[v];
    vl[cr]=upt(vl[cr]+vl[v]); v2[cr]=upt(v2[cr]+v2[v]);
    s2[cr]=(s2[cr]+s2[v]+(ll)siz[v]*(dep[v]-dep[cr]))%mod;//dep
      }
    if(a[cr])ret=(ret+(ll)dep[cr]*a[cr]%mod*upt(alsm-a[cr]))%mod;
    ret=(ret-dep[cr]*((ll)vl[cr]*vl[cr]%mod-v2[cr]))%mod;
  }
  void dfsx(int cr)
  {
    if(a[cr])ret=(ret+(ll)upt(sm[cr]-s2[cr])*a[cr]*2)%mod;
    for(int i=hd[cr],v;i;i=nxt[i])
      {
    v=to[i];
    int tp=(s2[cr]-s2[v]-(ll)siz[v]*(dep[v]-dep[cr]))%mod;
    tp=(tp+(ll)(tot-siz[v])*(dep[v]-dep[cr]))%mod;//tot
    s2[v]=upt(s2[v]+tp); dfsx(v);
      }
  }
  void solve()
  {
    build(); dfs(1); ret=upt(ret<<1); dfsx(1);//*2
    ret=(ll)ret*(p-1)%mod*pw(p,mod-2)%mod;
    ans=upt(ans+ret); ret=0;//ret=0
  }
}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void init()
{
  ll d;
  for(int i=2;i<=n;i++)
    {
      if(!vis[i])pri[++cnt]=i,mnd[i]=i,dy[i]=cnt;
      for(int j=1;j<=cnt&&(d=(ll)i*pri[j])<=n;j++)
    { vis[d]=1;mnd[d]=pri[j];if(i%pri[j]==0)break;}
    }
}
void dfs(int cr,int fa)
{
  siz[cr]=1; dfn[cr]=++tim; dep[cr]=dep[fa]+1;
  pre[cr][0]=fa;
  for(int t=1,d=fa;(d=pre[d][t-1]);t++)
    pre[cr][t]=d;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    dfs(v,cr); siz[cr]+=siz[v];
    sm[cr]=(upt(sm[cr]+sm[v])+siz[v]);
      }
}
void dfsx(int cr,int fa)
{
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    int tp=upt(sm[cr]-sm[v]-siz[v]);
    tp=upt(tp+n-siz[v]);
    sm[v]=upt(sm[v]+tp); dfsx(v,cr);
      }
}
int main()
{
  n=rdn(); init();
  for(int i=1;i<=n;i++)
    {
      a[i]=rdn(); int k=a[i];
      while(k>1)
    {
      int tp=1,d=mnd[k];
      while(mnd[k]==d)k/=d,tp*=d;
      vt[dy[d]].pb(Node(i,tp));//dy[]
    }
    }
  for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  dfs(1,0); dfsx(1,0);
  for(int i=1;i<=cnt;i++)
    if(vt[i].size())
      {
    VT::tot=0;
    for(int j=0,lm=vt[i].size();j<lm;j++)
      VT::q[++VT::tot]=vt[i][j];
    VT::p=pri[i]; VT::solve();
      }
  ans=(ll)ans*pw((ll)n*(n-1)%mod,mod-2)%mod;
  printf("%d\n",ans);
  return 0;
}
View Code

  题解:https://blog.sengxian.com/solutions/cf-809e

  原来都不太了解这样形式的莫比乌斯反演:

    若 \( f(i)=\sum g(倍数) \) ,则 \( g(i)=\sum f(倍数)*\mu(倍率) \) 

  一般这里可以用容斥解决。就是从大到小枚举,求出 \( \sum g(倍数) \) ,把多余的 f( ) 减去即可。此时更大的 f( ) 已求出了。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=2e5+5,K=17,mod=1e9+7;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,a[N],dy[N],hd[N],xnt,to[N<<1],nxt[N<<1];
int tim,dfn[N],dep[N],f[N];
int tot,L[N],R[N],q[N<<1],st[N<<1][K+5],lg[N<<1],bin[K+5];
int pri[N],phi[N],cnt; bool vis[N];
struct Node{
  int v,w;
  Node(int v=0,int w=0):v(v),w(w) {}
  bool operator< (const Node &b)const
  {return dfn[(*this).v]<dfn[b.v];}
};
namespace VT{
  Node q[N]; int tot,sta[N],top,ret,p;
  int a[N],sm,hd[N],xnt,to[N<<1],nxt[N<<1],vl[N],v2[N];
  void add(int x,int y)
  {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
  int get_lca(int x,int y)
  {
    if(dfn[y]<dfn[x])swap(x,y);//
    if(R[y]<=R[x])return x;//else R[x]->L[y]
    int l=R[x], r=L[y], d=lg[r-l+1];
    if(dfn[st[l][d]]<dfn[st[r-bin[d]+1][d]])
      return st[l][d];
    return st[r-bin[d]+1][d];
  }
  void build()
  {
    xnt=sm=0; sort(q+1,q+tot+1); int st,u=q[1].v;
    if(u!=1){sta[top=1]=1; a[1]=0; hd[1]=0; st=1;}//hd[]!!
    else {sta[top=1]=u; a[u]=q[1].w; sm=a[u]; hd[u]=0; st=2;}
    for(int i=st;i<=tot;i++)
      {
    u=q[i].v; int lca=get_lca(u,sta[top]); bool fg=0;
    a[u]=q[i].w; sm=upt(sm+a[u]);
    while(dfn[sta[top]]>dfn[lca])
      {
        if(dfn[sta[top-1]]>=dfn[lca])//dfn[lca] not lca!!
          add(sta[top-1],sta[top]);
        else {hd[lca]=0;fg=1;add(lca,sta[top]);}//
        top--;
      }
    if(sta[top]!=lca)
      { sta[++top]=lca; a[lca]=0; if(!fg)hd[lca]=0;}
    sta[++top]=u; hd[u]=0;
      }
    for(int i=top-1;i;i--)add(sta[i],sta[i+1]);
  }
  void dfs(int cr)
  {
    vl[cr]=a[cr]; v2[cr]=(ll)a[cr]*a[cr]%mod;
    for(int i=hd[cr],v;i;i=nxt[i])
      {
    dfs(v=to[i]);
    ret=(ret-2ll*dep[cr]*vl[cr]%mod*vl[v])%mod;
    vl[cr]=upt(vl[cr]+vl[v]);
    v2[cr]=upt(v2[cr]+v2[v]);
      }
    ret=(ret+(ll)dep[cr]*a[cr]%mod*upt(sm-a[cr]))%mod;
  }
  void solve()
  {
    build(); dfs(1); ret=upt(ret<<1);//*2
    f[p]=upt(ret); ret=0;//ret=0
  }
}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void init()
{
  ll d; phi[1]=1;//
  for(int i=2;i<=n;i++)
    {
      if(!vis[i])pri[++cnt]=i,phi[i]=i-1;
      for(int j=1;j<=cnt&&(d=(ll)i*pri[j])<=n;j++)
    {
      vis[d]=1;
      if(i%pri[j]==0){phi[d]=(ll)phi[i]*pri[j]%mod;break;}
      phi[d]=(ll)phi[i]*phi[pri[j]]%mod;
    }
    }
}
void dfs(int cr,int fa)
{
  dfn[cr]=++tim; dep[cr]=dep[fa]+1;
  q[++tot]=cr; L[cr]=tot;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa) dfs(v,cr),q[++tot]=cr;//every pass!!!
  R[cr]=tot;
}
void lca_ini()
{
  for(int i=2;i<=tot;i++)lg[i]=lg[i>>1]+1;
  bin[0]=1;
  for(int i=1;i<=lg[tot];i++)bin[i]=bin[i-1]<<1;
  for(int i=1;i<=tot;i++)st[i][0]=q[i];
  for(int t=1;t<=lg[tot];t++)
    for(int i=1;i+bin[t]-1<=tot;i++)
      {
    if(dfn[st[i][t-1]]<dfn[st[i+bin[t-1]][t-1]])
      st[i][t]=st[i][t-1];
    else st[i][t]=st[i+bin[t-1]][t-1];
      }
}
int main()
{
  n=rdn(); init();
  for(int i=1;i<=n;i++) a[i]=rdn(), dy[a[i]]=i;
  for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  dfs(1,0); lca_ini();
  for(int i=1;i<=n;i++)
    {
      VT::tot=0; VT::p=i;
      for(int j=i;j<=n;j+=i)
    VT::q[++VT::tot]=Node(dy[j],phi[j]);
      VT::solve();
    }
  int ans=0;
  for(int i=n;i;i--)
    {
      for(int j=i+i;j<=n;j+=i)
    f[i]=upt(f[i]-f[j]);
      ans=(ans+(ll)f[i]*i%mod*pw(phi[i],mod-2))%mod;
    }
  ans=(ll)ans*pw((ll)n*(n-1)%mod,mod-2)%mod;
  printf("%d\n",ans);
  return 0;
}
View Code

 

posted on 2019-04-24 21:33  Narh  阅读(138)  评论(0编辑  收藏  举报

导航