CF990G GCD Counting 点分治+容斥+暴力
只想出来 $O(nlogn\times 160)$ 的复杂度,没想到还能过~
Code:
#include <cstdio> #include <vector> #include <algorithm> #define N 200004 #define ll long long #define setIO(s) freopen(s".in","r",stdin) using namespace std; int n; vector<int>v[N]; ll answer[N],anss[N]; int prime[N],is[N],tot; int val[N],hd[N],to[N<<1],nex[N<<1],edges; int size[N],vis[N],mx[N],root,sn; int tl,tmp[N],viss[N]; ll f[N],g[N]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void init() { int i,j; for(i=2;i<N;++i) { if(!is[i]) prime[++tot]=i; for(j=1;j<=tot&&prime[j]*i<N;++j) { is[prime[j]*i]=1; if(i%prime[j]==0) break; } } for(i=1;i<N;++i) for(j=i;j<N;j+=i) v[j].push_back(i); } void getroot(int u,int ff) { size[u]=1,mx[u]=0; for(int i=hd[u];i;i=nex[i]) if(to[i]!=ff&&!vis[to[i]]) getroot(to[i],u),size[u]+=size[to[i]],mx[u]=max(mx[u],size[to[i]]); mx[u]=max(mx[u],sn-size[u]); if(mx[u]<mx[root]) root=u; } void dfs(int u,int ff,int num) { num=__gcd(num,val[u]); tmp[++tl]=num; for(int i=hd[u];i;i=nex[i]) if(to[i]!=ff&&!vis[to[i]]) dfs(to[i],u,num); } void calc(int u) { int i,j; tl=0; for(i=0;i<v[val[u]].size();++i) ++f[v[val[u]][i]],++anss[v[val[u]][i]]; for(i=hd[u];i;i=nex[i]) { if(vis[to[i]]) continue; int re=tl+1; dfs(to[i],u,val[u]); for(j=re;j<=tl;++j) { int a=tmp[j]; for(int k=0;k<v[a].size();++k) ++g[v[a][k]]; } for(j=re;j<=tl;++j) { int a=tmp[j]; for(int k=0;k<v[a].size();++k) if(!viss[v[a][k]]) { anss[v[a][k]]+=1ll*f[v[a][k]]*g[v[a][k]],viss[v[a][k]]=1; f[v[a][k]]+=g[v[a][k]]; } } for(j=re;j<=tl;++j) { int a=tmp[j]; for(int k=0;k<v[a].size();++k) { viss[v[a][k]]=0,g[v[a][k]]=0; } } } for(i=0;i<v[val[u]].size();++i) f[v[val[u]][i]]=0; for(i=1;i<=tl;++i) { int a=tmp[i]; for(j=0;j<v[a].size();++j) f[v[a][j]]=g[v[a][j]]=viss[v[a][j]]=0; } } void solve(int u) { vis[u]=1,calc(u); for(int i=hd[u];i;i=nex[i]) if(!vis[to[i]]) root=0,sn=size[to[i]],getroot(to[i],u),solve(root); } int main() { init(); int i,j,Mx=0; // setIO("input"); scanf("%d",&n); for(i=1;i<=n;++i) scanf("%d",&val[i]),Mx=max(Mx,val[i]); for(i=1;i<n;++i) { int a,b; scanf("%d%d",&a,&b),add(a,b),add(b,a); } mx[root=0]=sn=n,getroot(1,0),solve(root); for(i=Mx;i>=1;--i) { answer[i]=anss[i]; for(j=i+i;j<=Mx;j+=i) answer[i]-=answer[j]; } for(i=1;i<=Mx;++i) if(answer[i]) printf("%d %lld\n",i,answer[i]); return 0; }