CF809E Surprise me!(虚树+莫比乌斯反演)
CF809E Surprise me!(虚树+莫比乌斯反演)
题目大意
给定一棵 \(n\) 个节点的树,每个点有一个权值 \(a[i]\) ,保证 \(a[i]\) 是一个 \(1..n\) 的排列。
求
\[\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^n\varphi(a_i*a_j)·dist(i,j)
\]
其中, \(\varphi(x)\) 是欧拉函数, \(dist(i,j)\) 表示 \(i,j\) 两个节点在树上的距离。
数据范围
$ 2 \le n \le 2·10^{5} $
解题思路
套路题,需要一定的熟练度
首先想到当 \(a_i~a_j\) 互质时 \(\varphi(a_i * a_j) = \varphi(a_i) * \varphi(a_j)\),如果不互质发现直接乘会多乘一些,不难发现是
\[\prod_{P |a_i,P|a_j} (1 - \frac 1p) \tag{1}
\]
这些质数显然构成了 \(gcd(a_i, a_j)\) 的质因子,又因为 \(\varphi(x) = \prod_{P | x} (1 - \frac 1p) * x\),所以有
\[\varphi(a_i * a_j) = \frac{\varphi(a_i) * \varphi(a_j) * gcd(a_i, a_j)}{\varphi(gcd(a_i, a_j))}
\]
答案转化为
\[\frac{1}{n(n-1)}\sum_{d=1}^n \frac d{\varphi(d)} \sum_{i=1}^n\sum_{j=1}^n [gcd(a_i, a_j)==d]\varphi(a_i) *\varphi(a_j)·dist(i,j)
\]
套路的发现等于 d 的条件不好做,试试 \([d~|~gcd(a_i, a_j)]\)
设
\[F(x) = \sum_{i=1}^n\sum_{j=1}^n [gcd(a_i, a_j)==d]\varphi(a_i) *\varphi(a_j)·dist(i,j)\\
f(x) = \sum_{i=1}^n\sum_{j=1}^n [d~|~gcd(a_i, a_j)]\varphi(a_i) *\varphi(a_j)·dist(i,j)
\]
则有
\[f(x) = \sum_{x|d}F(d)
\]
莫比乌斯反演得
\[F(x) = \sum_{x|d} f(d) * \mu(\frac dx)
\]
所以我们统计 \(f(x)\) 就可以了
\[f(x) = \sum_{x |a_i} \sum_{x|a_j}\varphi(a_i)*\varphi(a_j)*(dep[i] + dep[j]-2*dep[lca])\\
=2 * (\sum_{x|a_i}\varphi(a_i)*dep[i])*\sum_{x|a_i}\varphi(a_j)-2*\sum_{x |a_i} \sum_{x|a_j}\varphi(a_i)*\varphi(a_j)*dep[lca]
\]
枚举公因数 d,\(a_i, a_j\) 一定为 d 的倍数,因为 a 是一个排列,根据调和级数,总个数是 \(\Theta(nlogn)\) 级别的。
所以可以对每个 d 都建立一颗其倍数的虚树,第一部分答案直接记录一下和即可,第二部分答案在虚树上统计答案即可
淡黄的长裙,蓬松的代码
using namespace std;
template <typename T>
void read(T &x) {
x = 0; bool f = 0;
char c = getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
for (;isdigit(c);c=getchar()) x=x*10+(c^48);
if (f) x=-x;
}
const int P = 1e9+7;
const int N = 400500;
ll fpw(ll x, ll mi) {
ll res = 1;
while (mi) {
if (mi & 1) res = res * x % P;
x = x * x % P; mi >>= 1;
}
return res;
}
int dep[N], f[N], Top[N], son[N], siz[N], a[N], n;
int dfn[N], h[N], ne[N<<1], to[N<<1], num, tot;
inline void add(int x, int y) {
ne[++tot] = h[x], to[h[x] = tot] = y;
}
void dfs1(int x, int fa) {
dep[x] = dep[f[x] = fa] + (siz[x] = 1), dfn[x] = ++num;
for (int i = h[x]; i ; i = ne[i]) {
int y = to[i]; if (y == fa) continue;
dfs1(y, x); siz[x] += siz[y];
if (siz[son[x]] < siz[y]) son[x] = y;
}
}
void dfs2(int x, int topf) {
Top[x] = topf;
if (!son[x]) return; dfs2(son[x], topf);
for (int i = h[x], y; i; i = ne[i])
if ((y = to[i]) != f[x] && y != son[x]) dfs2(y, y);
}
int lca(int x, int y) {
while (Top[x] != Top[y]) {
if (dep[Top[x]] < dep[Top[y]]) swap(x, y);
x = f[Top[x]];
}
return dep[x] < dep[y] ? x : y;
}
int prime[N], phi[N], mu[N], e[N], cnt;
void prework(void) {
phi[1] = mu[1] = 1;
for (int i = 2;i <= n; i++) {
if (!e[i]) prime[++cnt] = e[i] = i, phi[i] = i - 1, mu[i] = -1;
for (int j = 1;j <= cnt; j++) {
int t = prime[j] * i;
if (t > n) break; e[t] = prime[j];
if (prime[j] == e[i]) {
phi[t] = phi[i] * prime[j];
break;
}
phi[t] = phi[i] * (prime[j] - 1);
mu[t] = -mu[i];
}
}
}
ll g[N], G[N], res;
int tmp[N], p[N], st[N], top;
vector<int> v[N];
inline void add_e(int x, int y) { v[x].push_back(y); }
void insert(int x) {
if (top == 1) return st[++top] = x, void();
int Lca = lca(x, st[top]);
if (Lca == st[top]) return st[++top] = x, void();
while (top > 1 && dfn[st[top-1]] >= dfn[Lca])
add_e(st[top-1], st[top]), top--;
if (Lca != st[top]) add_e(Lca, st[top]), st[top] = Lca;
st[++top] = x;
}
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }
ll sum[N], ans;
void dp(int x, int k) {
if (a[x] % k == 0) {
ans = (ans + (ll)phi[a[x]] * phi[a[x]] * dep[x]) % P;
sum[x] = phi[a[x]];
}
else sum[x] = 0;
for (auto y: v[x]) {
dp(y, k); (ans += 2 * sum[x] * sum[y] % P * dep[x]) %= P;
(sum[x] += sum[y]) %= P;
}
v[x].clear();
}
int main() {
read(n); prework();
for (int i = 1;i <= n; i++) read(a[i]), p[a[i]] = i;
for (int i = 1, x, y; i < n; i++)
read(x), read(y), add(x, y), add(y, x);
dfs1(1, 0), dfs2(1, 1);
for (int d = 1; d * 2 <= n ; d++) {
st[top = 1] = 1, res = ans = 0;
ll sum = 0, sphi = 0;
for (int j = d;j <= n; j += d)
tmp[++res] = p[j], sum = (sum + (ll)phi[j] * dep[p[j]]) % P, sphi += phi[j];
sort(tmp + 1, tmp + res + 1, cmp);
if (tmp[1] != 1) insert(tmp[1]);
for (int i = 2; i <= res; i++) insert(tmp[i]);
while (top > 1) add_e(st[top-1], st[top]), top--;
dp(1, d); g[d] = (2 * (sphi % P) * sum - 2 * ans) % P;
if (g[d] < 0) g[d] += P;
}
for (int i = 1;i <= n; i++) {
for (int j = i;j <= n; j += i)
G[i] += g[j] * mu[j / i], G[i] = (G[i] + P) % P;
}
ans = 0;
for (int i = 1;i <= n; i++) ans = (ans + i * fpw(phi[i], P - 2) % P * G[i]) % P;
ans = (ans + P) % P;
printf ("%lld\n", ans * fpw((ll)n * (n - 1) % P, P - 2) % P);
return 0;
}