【题解】[ABC248G] GCD cost on the tree
欢迎收看古明地恋的心跳大冒险
思路
容斥 + dp.
\(\gcd\) 相关,考虑 \(\mu\) 反演或者 \(\varphi\) 反演。
本质上都和容斥差不多,不如直接一步到位考虑容斥。
把权值拆成 \(\gcd\) 和对应的方案数两部分,考虑求对应的方案数。
令 \(f[v]\) 表示 \(\gcd = v\) 时的合法路径总数,\(g[v]\) 表示 \(v \mid \gcd\) 时的合法路径总数,显然有 \(f[v] = g[v] - \sum\limits_{i = 2, iv \leq \max(a_i)} f[iv]\).
于是我们只需要求出 \(g\).
一个想法是考虑保留所有为 \(v\) 倍数的边,然后在形成的连通块中做一个 dp 统计答案。
类似长剖一样,先令 \(res[u]\) 表示 \(u\) 子树中所有非孤点路径的总点数。
然后考虑每个点对于答案的贡献:长度乘以贡献次数。用 \(len[u]\) 表示 \(u\) 子树中所有一端为 \(u\) 的非孤点路径总点数,\(cnt[u]\) 表示所有一端为 \(u\) 的非孤点路径数量。
转移比较好推:
考虑所有 \(u\) 的子结点 \(v\).
分别算出两边的贡献累加:\(res[u] = res[u] + cnt[u] \times len[v] + len[u] \times cnt[v]\).
与 \(v\) 相连的路径可以在顶端加上 \(u\),此时这 \(cnt[v]\) 条路径的长度都会加一,于是有 \(len[u] = len[u] + len[v] + cnt[v]\).
同理可得 \(cnt[u] = cnt[u] + cnt[v]\).
考虑枚举 \(\gcd\) 的时候给合法的结点打标记,这样每次只需要遍历连通块 dp。每个结点只会被标记 \(\sigma(a_i)\) 次,所以时间复杂度是 \(O(V \log V + n \max(\sigma(a_i)))\).
代码
#include <cstdio>
#include <vector>
using namespace std;
#define il inline
const int maxn = 1e5 + 5;
const int mod = 998244353;
int n, m;
int a[maxn], f[maxn], g[maxn];
int len[maxn], cnt[maxn], res[maxn];
bool vis[maxn];
vector<int> gr[maxn], nd[maxn];
il int max(const int &a, const int &b) { return (a >= b ? a : b); }
il int read()
{
int res = 0;
char ch = getchar();
while ((ch < '0') || (ch > '9')) ch = getchar();
while ((ch >= '0') && (ch <= '9')) res = res * 10 + ch - '0', ch = getchar();
return res;
}
void dfs(int u, int val)
{
cnt[u] = len[u] = 1, res[u] = 0, vis[u] = false;
for (int v : gr[u])
{
if (vis[v])
{
dfs(v, val);
res[u] = (res[u] + 1ll * cnt[u] * len[v] % mod + 1ll * len[u] * cnt[v] % mod) % mod;
len[u] = (1ll * len[u] + len[v] + cnt[v]) % mod;
cnt[u] = (cnt[u] + cnt[v]) % mod;
}
}
g[val] = (g[val] + res[u]) % mod;
}
int main()
{
n = read();
for (int i = 1; i <= n; i++) a[i] = read(), m = max(m, a[i]);
for (int i = 1, u, v; i <= n - 1; i++)
{
u = read(), v = read();
gr[u].push_back(v), gr[v].push_back(u);
}
for (int i = 1; i <= n; i++)
{
for (int j = 1; j * j <= a[i]; j++)
{
if (a[i] % j == 0)
{
nd[j].push_back(i);
if (j * j != a[i]) nd[a[i] / j].push_back(i);
}
}
}
for (int i = m; i; i--)
{
if (nd[i].size())
{
for (int v : nd[i]) vis[v] = true;
for (int v : nd[i])
if (vis[v]) dfs(v, i);
f[i] = g[i];
for (int j = 2 * i; j <= m; j += i) f[i] = (f[i] - f[j] + mod) % mod;
}
}
int ans = 0;
for (int i = 1; i <= m; i++) ans = (ans + 1ll * i * f[i] % mod) % mod;
printf("%d\n", ans);
return 0;
}