2019.4.24 一题(CF 809E)——推式子+虚树
题目:http://codeforces.com/contest/809/problem/E
原来以为可以每个质因子分开给答案贡献。
大概就是把有这个质因子的数都拿出来建虚树,这样虚树的总点数是 nlogn 的。
定义 b[ i ] 表示 i 点原来的权值分解出的 pt ,其中 p 是目前在做的质因子。那么要求虚树里的 \( \sum\limits_{u} \sum\limits_{v} dis(u,v)*b[u]*b[v] \) 。最后乘 \( \frac{p-1}{p} \) 即可。
把 dis( u , v ) 拆成 dep[ u ] + dep[ v ] - 2*dep[ lca ] ,就是每个点贡献 \( dep[cr]*b[cr]*\sum\limits_{i!=cr}b[i] \) ,\( -2*dep[cr]*( (\sum\limits_{i \in tree_cr}b[i])^2 - \sum\limits_{i \in tree_cr}b[i]^2 ) \)
还要考虑当前质因子的 “虚树上的点与虚树外的点” 的贡献,就是 \( b[cr]*\sum\limits_{i}dis(i,cr) \) 。要求 “该点与虚树外的点的距离和” ,可以用 “该点与所有点的距离和” - “该点与虚树上点的距离和” , 换根 DP 一番即可。
然后把各种质因子的贡献加起来。
试了一下自己造的样例:
4
1 2 3 4
1 2
1 3
3 4
发现可以。也没试试题面的样例,就开始写。写完发现是不对的。
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> #define ll long long #define pb push_back using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=2e5+5,K=17,mod=1e9+7; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n,a[N],hd[N],xnt,to[N<<1],nxt[N<<1],ans; int tim,dfn[N],dep[N],pre[N][K+5],siz[N],sm[N]; int pri[N],mnd[N],cnt,dy[N]; bool vis[N]; struct Node{ int v,w; Node(int v=0,int w=0):v(v),w(w) {} bool operator< (const Node &b)const {return dfn[(*this).v]<dfn[b.v];} }; vector<Node> vt[N]; namespace VT{ Node q[N]; int tot,sta[N],top,ret,p; int a[N],hd[N],xnt,to[N<<1],nxt[N<<1]; int siz[N],s2[N],alsm,vl[N],v2[N]; void add(int x,int y) {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} int get_lca(int x,int y) { if(dep[x]<dep[y])swap(x,y); for(int t=K;t>=0;t--) if(dep[pre[x][t]]>=dep[y])x=pre[x][t]; if(x==y)return x; for(int t=K;t>=0;t--) if(pre[x][t]!=pre[y][t]) x=pre[x][t], y=pre[y][t]; return pre[x][0]; } void build() { xnt=alsm=0; sort(q+1,q+tot+1); int st,u=q[1].v; if(u!=1){sta[top=1]=1; a[1]=0; hd[1]=0; st=1;}//hd[]!! else {sta[top=1]=u; a[u]=q[1].w; hd[u]=0; st=2;} for(int i=st;i<=tot;i++) { u=q[i].v; int lca=get_lca(u,sta[top]); a[u]=q[i].w; alsm=upt(alsm+a[u]); while(dfn[sta[top]]>dfn[lca]) { if(dfn[sta[top-1]]>=lca)add(sta[top-1],sta[top]); else add(lca,sta[top]); top--; } if(sta[top]!=lca) { sta[++top]=lca; a[lca]=0; hd[lca]=0;} sta[++top]=u; hd[u]=0; } for(int i=top-1;i;i--)add(sta[i],sta[i+1]); } void dfs(int cr) { if(a[cr]){siz[cr]=1;vl[cr]=a[cr];v2[cr]=(ll)a[cr]*a[cr]%mod;} else siz[cr]=vl[cr]=v2[cr]=0; s2[cr]=0; for(int i=hd[cr],v;i;i=nxt[i]) { dfs(v=to[i]); siz[cr]+=siz[v]; vl[cr]=upt(vl[cr]+vl[v]); v2[cr]=upt(v2[cr]+v2[v]); s2[cr]=(s2[cr]+s2[v]+(ll)siz[v]*(dep[v]-dep[cr]))%mod;//dep } if(a[cr])ret=(ret+(ll)dep[cr]*a[cr]%mod*upt(alsm-a[cr]))%mod; ret=(ret-dep[cr]*((ll)vl[cr]*vl[cr]%mod-v2[cr]))%mod; } void dfsx(int cr) { if(a[cr])ret=(ret+(ll)upt(sm[cr]-s2[cr])*a[cr]*2)%mod; for(int i=hd[cr],v;i;i=nxt[i]) { v=to[i]; int tp=(s2[cr]-s2[v]-(ll)siz[v]*(dep[v]-dep[cr]))%mod; tp=(tp+(ll)(tot-siz[v])*(dep[v]-dep[cr]))%mod;//tot s2[v]=upt(s2[v]+tp); dfsx(v); } } void solve() { build(); dfs(1); ret=upt(ret<<1); dfsx(1);//*2 ret=(ll)ret*(p-1)%mod*pw(p,mod-2)%mod; ans=upt(ans+ret); ret=0;//ret=0 } } void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void init() { ll d; for(int i=2;i<=n;i++) { if(!vis[i])pri[++cnt]=i,mnd[i]=i,dy[i]=cnt; for(int j=1;j<=cnt&&(d=(ll)i*pri[j])<=n;j++) { vis[d]=1;mnd[d]=pri[j];if(i%pri[j]==0)break;} } } void dfs(int cr,int fa) { siz[cr]=1; dfn[cr]=++tim; dep[cr]=dep[fa]+1; pre[cr][0]=fa; for(int t=1,d=fa;(d=pre[d][t-1]);t++) pre[cr][t]=d; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs(v,cr); siz[cr]+=siz[v]; sm[cr]=(upt(sm[cr]+sm[v])+siz[v]); } } void dfsx(int cr,int fa) { for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { int tp=upt(sm[cr]-sm[v]-siz[v]); tp=upt(tp+n-siz[v]); sm[v]=upt(sm[v]+tp); dfsx(v,cr); } } int main() { n=rdn(); init(); for(int i=1;i<=n;i++) { a[i]=rdn(); int k=a[i]; while(k>1) { int tp=1,d=mnd[k]; while(mnd[k]==d)k/=d,tp*=d; vt[dy[d]].pb(Node(i,tp));//dy[] } } for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); dfs(1,0); dfsx(1,0); for(int i=1;i<=cnt;i++) if(vt[i].size()) { VT::tot=0; for(int j=0,lm=vt[i].size();j<lm;j++) VT::q[++VT::tot]=vt[i][j]; VT::p=pri[i]; VT::solve(); } ans=(ll)ans*pw((ll)n*(n-1)%mod,mod-2)%mod; printf("%d\n",ans); return 0; }
题解:https://blog.sengxian.com/solutions/cf-809e
原来都不太了解这样形式的莫比乌斯反演:
若 \( f(i)=\sum g(倍数) \) ,则 \( g(i)=\sum f(倍数)*\mu(倍率) \)
一般这里可以用容斥解决。就是从大到小枚举,求出 \( \sum g(倍数) \) ,把多余的 f( ) 减去即可。此时更大的 f( ) 已求出了。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=2e5+5,K=17,mod=1e9+7; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n,a[N],dy[N],hd[N],xnt,to[N<<1],nxt[N<<1]; int tim,dfn[N],dep[N],f[N]; int tot,L[N],R[N],q[N<<1],st[N<<1][K+5],lg[N<<1],bin[K+5]; int pri[N],phi[N],cnt; bool vis[N]; struct Node{ int v,w; Node(int v=0,int w=0):v(v),w(w) {} bool operator< (const Node &b)const {return dfn[(*this).v]<dfn[b.v];} }; namespace VT{ Node q[N]; int tot,sta[N],top,ret,p; int a[N],sm,hd[N],xnt,to[N<<1],nxt[N<<1],vl[N],v2[N]; void add(int x,int y) {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} int get_lca(int x,int y) { if(dfn[y]<dfn[x])swap(x,y);// if(R[y]<=R[x])return x;//else R[x]->L[y] int l=R[x], r=L[y], d=lg[r-l+1]; if(dfn[st[l][d]]<dfn[st[r-bin[d]+1][d]]) return st[l][d]; return st[r-bin[d]+1][d]; } void build() { xnt=sm=0; sort(q+1,q+tot+1); int st,u=q[1].v; if(u!=1){sta[top=1]=1; a[1]=0; hd[1]=0; st=1;}//hd[]!! else {sta[top=1]=u; a[u]=q[1].w; sm=a[u]; hd[u]=0; st=2;} for(int i=st;i<=tot;i++) { u=q[i].v; int lca=get_lca(u,sta[top]); bool fg=0; a[u]=q[i].w; sm=upt(sm+a[u]); while(dfn[sta[top]]>dfn[lca]) { if(dfn[sta[top-1]]>=dfn[lca])//dfn[lca] not lca!! add(sta[top-1],sta[top]); else {hd[lca]=0;fg=1;add(lca,sta[top]);}// top--; } if(sta[top]!=lca) { sta[++top]=lca; a[lca]=0; if(!fg)hd[lca]=0;} sta[++top]=u; hd[u]=0; } for(int i=top-1;i;i--)add(sta[i],sta[i+1]); } void dfs(int cr) { vl[cr]=a[cr]; v2[cr]=(ll)a[cr]*a[cr]%mod; for(int i=hd[cr],v;i;i=nxt[i]) { dfs(v=to[i]); ret=(ret-2ll*dep[cr]*vl[cr]%mod*vl[v])%mod; vl[cr]=upt(vl[cr]+vl[v]); v2[cr]=upt(v2[cr]+v2[v]); } ret=(ret+(ll)dep[cr]*a[cr]%mod*upt(sm-a[cr]))%mod; } void solve() { build(); dfs(1); ret=upt(ret<<1);//*2 f[p]=upt(ret); ret=0;//ret=0 } } void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void init() { ll d; phi[1]=1;// for(int i=2;i<=n;i++) { if(!vis[i])pri[++cnt]=i,phi[i]=i-1; for(int j=1;j<=cnt&&(d=(ll)i*pri[j])<=n;j++) { vis[d]=1; if(i%pri[j]==0){phi[d]=(ll)phi[i]*pri[j]%mod;break;} phi[d]=(ll)phi[i]*phi[pri[j]]%mod; } } } void dfs(int cr,int fa) { dfn[cr]=++tim; dep[cr]=dep[fa]+1; q[++tot]=cr; L[cr]=tot; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) dfs(v,cr),q[++tot]=cr;//every pass!!! R[cr]=tot; } void lca_ini() { for(int i=2;i<=tot;i++)lg[i]=lg[i>>1]+1; bin[0]=1; for(int i=1;i<=lg[tot];i++)bin[i]=bin[i-1]<<1; for(int i=1;i<=tot;i++)st[i][0]=q[i]; for(int t=1;t<=lg[tot];t++) for(int i=1;i+bin[t]-1<=tot;i++) { if(dfn[st[i][t-1]]<dfn[st[i+bin[t-1]][t-1]]) st[i][t]=st[i][t-1]; else st[i][t]=st[i+bin[t-1]][t-1]; } } int main() { n=rdn(); init(); for(int i=1;i<=n;i++) a[i]=rdn(), dy[a[i]]=i; for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); dfs(1,0); lca_ini(); for(int i=1;i<=n;i++) { VT::tot=0; VT::p=i; for(int j=i;j<=n;j+=i) VT::q[++VT::tot]=Node(dy[j],phi[j]); VT::solve(); } int ans=0; for(int i=n;i;i--) { for(int j=i+i;j<=n;j+=i) f[i]=upt(f[i]-f[j]); ans=(ans+(ll)f[i]*i%mod*pw(phi[i],mod-2))%mod; } ans=(ll)ans*pw((ll)n*(n-1)%mod,mod-2)%mod; printf("%d\n",ans); return 0; }