【2019北京集训测试赛(十三)】函树 虚树
题目大意:给你一颗$n$个节点的树,定义$d(x,y)=$点$x$到点$y$最短路上经过的边数。
求$\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{n} \varphi(i\times j)\times d(i,j)$
答案对998244353$取模。
我们对这个式子做一些细微的处理,设最终的答案为$ans$:
$ans=\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{n} \varphi(i\times j)\times d(i,j)$
$=\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{n} \varphi(i)\varphi(j)\frac{gcd(i,j)}{\varphi(gcd(i,j))}\times d(i,j)$
我们设$F(d)=\sum\limits_{i=1}^{n} \sum\limits_{j=1,d|gcd(i,j)}^{n} \varphi(i)\varphi(j)\times d(i,j)$
那么,$ans=\sum\limits_{d=1}^{n} \frac{d}{\varphi(d)} \sum\limits_{p|d} F(p)\times G(\frac{d}{p})$
对于$G(x)$,设$x=\prod\limits_{i=1}^{k} p_i$,$p_i$是质数,$G(x)=(-1)^k$
我们考虑如何求$F(d)$。
显然,我们只需要把所有点权能被$d$整除的点找出来,建一棵虚树,统计每条虚树边两端的\$sum \varphi(i)$,把它们乘起来,再乘上虚树边边长即可。
由于点权等于编号,所以n棵虚树的总点数是$O(n\ln\ n)$级别的,单次构建虚树的复杂度是$O(size\log\ size)$的,所以并不会$T$掉。
然后就没有了,总复杂度是$O(n\log^2\ n)$的。
1 #include<bits/stdc++.h> 2 #define M 100005 3 #define MOD 998244353 4 #define L long long 5 using namespace std; 6 7 int pri[M]={0},b[M]={0},phi[M]={0},zf[M]={0},Use=0; 8 void init(){ 9 phi[1]=1; zf[1]=1; 10 for(int i=2;i<M;i++){ 11 if(!b[i]) pri[++Use]=i,phi[i]=i-1,zf[i]=-1; 12 for(int j=1;j<=Use&&i*pri[j]<M;j++){ 13 b[i*pri[j]]=1; zf[i*pri[j]]=-zf[i]; 14 if(i%pri[j]==0) {phi[i*pri[j]]=phi[i]*pri[j]; break;} 15 phi[i*pri[j]]=phi[i]*(pri[j]-1); 16 } 17 } 18 } 19 20 L pow_mod(L x,L k){L ans=1; for(;k;k>>=1,x=x*x%MOD) if(k&1) ans=ans*x%MOD; return ans;} 21 vector<int> G[M]; 22 23 struct edge{int u,v,next;}e[M*2]={0}; int head[M]={0},use=0; 24 void add(int x,int y,int z){use++;e[use].u=y;e[use].v=z;e[use].next=head[x];head[x]=use;} 25 int n,a[M]={0}; 26 27 int dep[M]={0},dfn[M]={0},low[M]={0},f[M][20]={0},t=0; 28 void dfs(int x,int fa){ 29 dep[x]=dep[fa]+1; dfn[x]=++t; f[x][0]=fa; 30 for(int i=1;i<20;i++) f[x][i]=f[f[x][i-1]][i-1]; 31 for(int i=0;i<G[x].size();i++) if(G[x][i]!=fa) dfs(G[x][i],x); 32 low[x]=t; 33 } 34 int getlca(int x,int y){ 35 if(dep[x]<dep[y]) swap(x,y); int cha=dep[x]-dep[y]; 36 for(int i=19;~i;i--) if((1<<i)&cha) x=f[x][i]; 37 if(x==y) return x; 38 for(int i=19;~i;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; 39 return f[x][0]; 40 } 41 42 vector<int> D[M]; L F[M]={0}; 43 44 bool cmp(int x,int y){return dfn[x]<dfn[y];} 45 int point[M]={0},stk[M]={0},is[M]={0},pcnt=0,cnt=0,nowt=0; 46 void build(){ 47 pcnt=cnt; int siz=0; nowt=0; 48 sort(point+1,point+cnt+1,cmp); 49 for(int i=1;i<=cnt;i++){ 50 int last=0; 51 while(siz&&getlca(stk[siz],point[i])!=stk[siz]) last=stk[siz],stk[siz--]=0; 52 if(last){ 53 int lca=getlca(last,point[i]); 54 if(lca!=stk[siz]){ 55 stk[++siz]=lca; 56 point[++pcnt]=lca; 57 is[lca]=0; 58 } 59 } 60 stk[++siz]=point[i]; is[point[i]]=1; 61 } 62 sort(point+1,point+pcnt+1,cmp); 63 while(siz) stk[siz--]=0; 64 } 65 L sumphi[M]={0},sum=0; 66 int dfs(int x){ 67 if(is[x]) sumphi[x]=phi[a[x]]; else sumphi[x]=0; 68 int v; nowt++; 69 while(dfn[v=point[nowt]]<=low[x]&&nowt<=pcnt){ 70 dfs(v); 71 sumphi[x]=(sumphi[x]+sumphi[v])%MOD; 72 } 73 } 74 void getans(int x,L fsum){ 75 int v; nowt++; 76 while(dfn[v=point[nowt]]<=low[x]&&nowt<=pcnt){ 77 sum=(sum+1LL*(dep[v]-dep[x])*sumphi[v]%MOD*(fsum+sumphi[x]-sumphi[v]+MOD))%MOD; 78 getans(v,(fsum+sumphi[x]-sumphi[v])%MOD); 79 } 80 } 81 void solve(int x){ 82 while(pcnt)point[pcnt--]=0; cnt=sum=0; 83 for(int i=0;i<D[x].size();i++){ 84 point[++cnt]=D[x][i]; 85 } 86 build(); 87 nowt=1; dfs(point[1]); 88 nowt=1; getans(point[1],0); 89 F[x]=sum; 90 } 91 92 int main(){ 93 // freopen("in.txt","r",stdin); 94 // freopen("out.txt","w",stdout); 95 init(); 96 scanf("%d",&n); 97 for(int i=1;i<=n;i++){ 98 a[i]=i; //scanf("%d",a+i); 99 for(int j=1;j*j<=a[i];j++) if(a[i]%j==0){ 100 D[j].push_back(i); 101 if(j*j!=a[i]) D[a[i]/j].push_back(i); 102 } 103 } 104 for(int i=1;i<n;i++){ 105 int x,y; scanf("%d%d",&x,&y); 106 G[x].push_back(y); G[y].push_back(x); 107 } 108 dfs(1,0); 109 for(int i=1;i<=n;i++) solve(i); 110 for(int i=n;i;i--){ 111 for(int j=i*2;j<=n;j+=i) 112 F[i]=(F[i]-F[j]+MOD)%MOD; 113 } 114 L ans=0; 115 for(int d=1;d<=n;d++) 116 ans=(ans+F[d]*d%MOD*pow_mod(phi[d],MOD-2))%MOD; 117 cout<<ans*2%MOD<<endl; 118 //cout<<ans*2*pow_mod(1LL*n*(n-1)%MOD,MOD-2)%MOD<<endl; 119 }