Loading

UOJ #33 树上 GCD

套路地,对每个 \(g \in [1, n - 1]\) 求出有多少对无序对 \((u, v)\) 满足 \(g|f(u, v)\),然后就可以用一个倒序枚举的 \(\mathcal O(n \log n)\) 的容斥求出答案。

考虑点分治,对每个经过分治中心 \(c\)\(u \rightsquigarrow \text{lca}(u, v) \rightsquigarrow v\) 只需分两种情况讨论(为了方便表述,记 \(\mathcal T_u\) 表示以 \(u\) 为根的子树):

  • \(u, v \in \mathcal T_c\)

    此时显然 \(\text{lca}(u, v) = c\),开个桶做一下就好。

    具体地,维护 \(cnt_i = \sum\limits_{u \in \mathcal T_c}[d(c, u) = i], c_i = \sum\limits_{i | j}cnt_j\),然后合并的时候枚举 \(g\) 统计贡献即可,单层时间复杂度是 \(\mathcal O(n \log n)\)

  • \(u, v\) 一个在 \(\mathcal T_c\) 内,一个在 \(\mathcal T_c\) 外(可以钦定此时 \(u \in \mathcal T_c, v \notin \mathcal T_c\)

    同样是维护 \(cnt_i = \sum\limits_{u \in \mathcal T_c}[d(c, u) = i]\),然后对从分治中心到当前子树的根节点路径上的点维护一个同样的桶 \(cnt'\),合并时枚举 \(g\),还要做在 \(cnt\) 上查询某个下标间隔为 \(g\) 的子序列和,预处理 \(g \le H\) 的部分就能做到单层 \(\mathcal O(n \sqrt n)\)

    关于时间复杂度的计算:记 \(H = \max\limits_{u \in \mathcal T_c}d(c, u)\)。对 \(g > \sqrt H\),单次查询时间复杂度 \(\mathcal O(\sqrt H)\);对 \(g \le \sqrt H\),对每个 \(g\) 恰好预处理所有 \(g\) 个不同的子序列和,时间复杂度 \(\mathcal O(H \sqrt H)\),此后查询 \(\mathcal O(1)\)。总查询次数 \(\mathcal O(n)\),故时间复杂度为 \(\mathcal O(n \sqrt H)\),也即单层 \(\mathcal O(n \sqrt n)\)。如果仔细算了时间复杂度,自然会发现常数非常小。

然后我们知道了单层时间复杂度是 \(\mathcal O(n \sqrt n)\),点分治的时间复杂度是 \(\mathcal O(n \log n)\),根据主定理,取大的单层时间复杂度,总体时间复杂度是 \(\mathcal O(n \sqrt n)\)

代码:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

constexpr int N = 2e5 + 10, SN = 510;

int n, fa[N];
ll F[N], ans[N], tmp[N];

vector<int> G[N];
inline void add(int u, int v) {G[u].emplace_back(v), G[v].emplace_back(u);}

int rt, tot, sz[N], mx[N];
bool vis[N];
void gettot(int u, int fa) {tot++; for (int v : G[u]) if (v != fa && !vis[v]) gettot(v, u);}
void getrt(int u, int fa) {
    sz[u] = 1, mx[u] = 0;
    for (int v : G[u]) if (v != fa && !vis[v]) {
        getrt(v, u), sz[u] += sz[v], mx[u] = max(mx[u], sz[v]);
    }
    mx[u] = max(mx[u], tot - sz[u]);
    if (mx[u] < mx[rt]) rt = u;
}

int maxd, cntv[N];
void getd(int u, int fa, int dep) {
    maxd = max(maxd, dep), cntv[dep]++;
    for (int v : G[u]) if (v != fa && !vis[v]) getd(v, u, dep + 1);
}

ll cntu[N], sum[N], siu[N], f[SN][SN];
void solve(int u) {
    int H = vis[u] = cntu[0] = 1;
    for (int v : G[u]) if (!vis[v] && v != fa[u]) {
        maxd = 0; getd(v, u, 1); H = max(H, maxd);
        for (int i = 1; i <= maxd; i++) {
            ans[i] += cntv[i], cntu[i] += cntv[i];
            for (int j = 2 * i; j <= maxd; j += i) cntv[i] += cntv[j];
            sum[i] += cntv[i], siu[i] += 1ll * cntv[i] * cntv[i];
        }
        fill(cntv + 1, cntv + maxd + 1, 0);
    }
    for (int i = 1; i <= H; i++) F[i] += (sum[i] * sum[i] - siu[i]) / 2, sum[i] = siu[i] = 0;
    int sH = sqrt(H);
    for (int g = 1; g <= sH; g++) {
        for (int i = 0; i < g; i++) {
            f[g][i] = 0;
            for (int j = i; j <= H; j += g) f[g][i] += cntu[j];
        }
    }
    int h = 0;
    for (int v = fa[u], son = u; !vis[v]; son = v, v = fa[v]) {
        h++, maxd = 0;
        if (vis[son]) getd(v, fa[v], 0);
        else vis[son] = 1, getd(v, fa[v], 0), vis[son] = 0;
        for (int i = 1; i <= maxd; i++) {
            for (int j = 2 * i; j <= maxd; j += i) cntv[i] += cntv[j];
        }
        for (int i = 1; i <= maxd; i++) {
            if (i <= sH) F[i] += cntv[i] * f[i][(i - h % i) % i];
            else {
                ll c = 0;
                for (int j = (i - h % i) % i; j <= H; j += i) c += cntu[j];
                F[i] += cntv[i] * c;
            }
        }
        fill(cntv, cntv + maxd + 1, 0);
    }
    for (int i = 0; i <= H; i++) tmp[i + 1] += cntu[i], tmp[i + h + 1] -= cntu[i];
    for (int i = 1; i <= H + h; i++) tmp[i] += tmp[i - 1], ans[i] += tmp[i];
    fill(tmp + 1, tmp + H + h + 2, 0), fill(cntu + 1, cntu + H + 1, 0);
    for (int v : G[u]) if (!vis[v]) {
        rt = tot = 0; gettot(v, u); getrt(v, u); solve(rt);
    }
}

int main() {
    ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
    cin >> n;
    for (int i = 2; i <= n; i++) cin >> fa[i], add(fa[i], i);
    mx[0] = tot = n; getrt(1, 0), vis[0] = 1, solve(rt);
    for (int i = n - 1; i; i--) {
        for (int j = 2 * i; j < n; j += i) F[i] -= F[j];
        ans[i] += F[i];
    }
    for (int i = 1; i < n; i++) cout << ans[i] << '\n';
    return 0;
}
posted @ 2024-01-29 10:17  Chy12321  阅读(12)  评论(0编辑  收藏  举报