Codeforces 990G GCD Counting 题解
设 \(k\) 的答案为 \(g(k)\),直接计算 \(g(k)\) 貌似很难,设 \(f(k)\) 为 \(k\mid\gcd(x,y)\) 的 \((x,y),x\leq y\) 个数。(这里定义 \(\gcd(x,y)\) 为 \(x\) 到 \(y\) 最短路径的点权 \(\gcd\))
可以莫比乌斯反演一下,有:
\[f(d)=\sum_{d|n}g(n)\Rightarrow g(d)=\sum_{d|n}\mu(\frac{n}{d})f(n)
\]
假如我们已经处理出了 \(f\),筛出 \(\mu\) 后就能在 \(\mathcal{O}(n\log n)\) 的复杂度内算出 \(g\)。
如何算 \(f\) ?注意到 \(\leq 2\times 10^5\) 的最大约数个数是 \(\leq 240\) 的,可以暴力找出所有约数,并在这些约数为编号的图中加入这条边。
对于每个数 \(i\) 为编号的图中统计 \(f_i\),用并查集,每次合并连通块时候 \(f_i\) 加上跨合并连通块的这条边左右两边点数的乘积。
总复杂度为 \(\mathcal{O}(n\log n+n\sqrt a)\)。
#include<iostream>
#include<cstdio>
#include<vector>
typedef long long ll;
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template <typename T>
T& read(T& r) {
r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
while(ch >= '0' && ch <= '9') r = r * 10 + (ch ^ 48), ch = getchar();
return r = w ? -r : r;
}
inline int gcd(int x, int y) { return !y ? x : gcd(y, x % y); }
const int N = 200005;
int n, mx, a[N];
ll f[N];
std::vector<int>vec[N];
struct DSU {
int fa[N], siz[N];
int find(int x) { return fa[x] = fa[x] == x ? x : find(fa[x]); }
void merge(int t, int x, int y) {
int fx = find(x), fy = find(y);
if(fx == fy) return ;
f[t] += 1ll * siz[fx] * siz[fy];
fa[fx] = fy;
siz[fy] += siz[fx];
}
}dsu;
int prime[N], pct, mu[N];
int lu[N], lv[N];
bool vis[N];
void init() {
vis[1] = 1; mu[1] = 1;
for(int i = 2; i <= mx; ++i) {
if(!vis[i]) {
prime[++pct] = i;
mu[i] = -1;
}
for(int j = 1; j <= pct && i * prime[j] <= mx; ++j) {
vis[i * prime[j]] = 1;
if(i % prime[j] == 0) { mu[i * prime[j]] = 0; break; }
mu[i * prime[j]] = -mu[i];
}
}
}
int main() {
read(n);
for(int i = 1; i <= n; ++i) {
read(a[i]);
mx = Max(mx, a[i]);
for(int j = 1; j * j <= a[i]; ++j) {
if(a[i] % j) continue ;
++f[j];
if(j * j != a[i]) ++f[a[i]/j];
}
}
init();
for(int i = 1; i < n; ++i) {
read(lu[i]); read(lv[i]);
int g = gcd(a[lu[i]], a[lv[i]]);
for(int j = 1; j * j <= g; ++j)
if(g % j == 0) {
vec[j].push_back(i);
if(j * j != g) vec[g / j].push_back(i);
}
}
for(int i = 1; i <= mx; ++i) {
for(auto x : vec[i]) {
dsu.fa[lu[x]] = lu[x];
dsu.fa[lv[x]] = lv[x];
dsu.siz[lu[x]] = 1;
dsu.siz[lv[x]] = 1;
}
for(auto x : vec[i])
dsu.merge(i, lu[x], lv[x]);
}
for(int i = 1; i <= mx; ++i) {
ll ans = 0;
for(int j = i; j <= mx; j += i) ans += mu[j / i] * f[j];
if(ans) printf("%d %lld\n", i, ans);
}
return 0;
}