NOI2021 轻重边

这是一个有点邪教的两 \(\log\) 做法,并且常数大,过题需要卡。

看到操作和题目名称,不难想到树剖。

不妨设以 \(1\) 为根,一条边的贡献由其深度较大的那个端点计算。难点在于如何安排 dfn 序,使得一条链和与一条链上的点直接相连的点的 dfn 是连续的。重链剖分后对于每条重链处理 dfn 序。考虑把处理 dfn 序的 dfs 改为 bfs,用队列记录待处理重链的顶端结点。首先让 \(1\) 入队,对于队首 \(u\)首先为 \(u\) 所在的重链安排连续的 dfn 序,然后为这条重链上所有点的所有轻儿子安排连续的 dfn 序,再将这条重链上所有点的所有轻儿子入队。发现重链顶端结点的 dfn 序和其他结点并不连续,不过不影响复杂度。

于是修改和查询的时候拆成两条路径就可以跳重链做了,需要注意的是修改第二条路径会对影响第一条的修改(第一条路径中 LCA 的儿子代表的边会被修改为轻边),暴力改回来即可。

时间复杂度 \(\mathcal{O}(n \log^2 n)\),实现起来较为麻烦,且需要卡常(例如我的代码中线段树区间修改和单点修改拆开了)。

#include <bits/stdc++.h>

using namespace std;

#define il inline
#define re register
#define rep(i, s, e) for (re int i = s; i <= e; ++i)
#define drep(i, s, e) for (re int i = s; i >= e; --i)
#define file(a) freopen(#a".in", "r", stdin), freopen(#a".out", "w", stdout)

const int N = 1000000 + 10;

il int read() {
    int x = 0; bool f = true; char c = getchar();
    while (!isdigit(c)) {if (c == '-') f = false; c = getchar();}
    while (isdigit(c)) x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return f ? x : -x;
}

int n, m;
vector <int> e[N];

int dat[N << 2], tag[N << 2];

#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid (l + r >> 1)

void reset(int u, int l, int r) {
    dat[u] = 0, tag[u] = -1;
    if (l == r) return;
    reset(ls, l, mid), reset(rs, mid + 1, r);
}

il void pushdown(int u, int l, int r) {
    if (tag[u] == -1) return;
    dat[ls] = tag[u] * (mid - l + 1), dat[rs] = tag[u] * (r - mid);
    tag[ls] = tag[u], tag[rs] = tag[u], tag[u] = -1;
}

void modify(int ml, int mr, int k, int u, int l, int r) {
    if (ml <= l && r <= mr) { dat[u] = k * (r - l + 1), tag[u] = k; return; }
    pushdown(u, l, r);
    if (ml <= mid) modify(ml, mr, k, ls, l, mid);
    if (mr > mid) modify(ml, mr, k, rs, mid + 1, r);
    dat[u] = dat[ls] + dat[rs];
}

void single_modify(int p, int k, int u, int l, int r) {
    if (l == r) { dat[u] = tag[u] = k; return; }
    pushdown(u, l, r);
    if (p <= mid) single_modify(p, k, ls, l, mid);
    else single_modify(p, k, rs, mid + 1, r);
    dat[u] = dat[ls] + dat[rs];
}

int query(int ql, int qr, int u, int l, int r) {
    if (ql <= l && r <= qr) return dat[u];
    pushdown(u, l, r);
    int res = 0;
    if (ql <= mid) res += query(ql, qr, ls, l, mid);
    if (qr > mid) res += query(ql, qr, rs, mid + 1, r);
    return res;
}

int sz[N], mson[N], dep[N], fr[N];
void dfs1(int u, int fa) {
    sz[u] = 1, mson[u] = 0, fr[u] = fa;
    for (int v : e[u]) {
        if (v == fa) continue;
        dep[v] = dep[u] + 1, dfs1(v, u), sz[u] += sz[v];
        if (sz[v] > sz[mson[u]]) mson[u] = v;
    }
}

int top[N], dfn[N], dcnt, lef[N], rig[N];
void dfs2(int u, int fa, int tp) {
    top[u] = tp;
    if (mson[u]) dfs2(mson[u], u, tp);
    for (int v : e[u]) {
        if (v != fa && v != mson[u]) dfs2(v, u, v);
    }
}

void make_dfn() {
    queue <int> q;
    q.push(1);
    while (q.size()) {
        int u = q.front(); q.pop();
        for (int v = mson[u]; v; v = mson[v]) dfn[v] = ++ dcnt;
        for (int v = u; v; v = mson[v]) {
            lef[v] = dcnt + 1;
            for (int w : e[v]) if (w != fr[v] && w != mson[v]) dfn[w] = ++ dcnt, q.push(w);
            rig[v] = dcnt;
        }
    }
}

int LCA(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        u = fr[top[u]];
    }
    return (dep[u] < dep[v]) ? u : v;
}

void doit(int u, int lca) {
    int v = 0;
    while (dep[lca] < dep[top[u]]) {
        if (mson[u]) single_modify(dfn[mson[u]], 0, 1, 1, n);
        if (lef[top[u]] <= rig[u]) modify(lef[top[u]], rig[u], 0, 1, 1, n);
        if (u != top[u]) modify(dfn[mson[top[u]]], dfn[u], 1, 1, 1, n);
        if (v) single_modify(dfn[v], 1, 1, 1, n);
        u = top[u], v = u, single_modify(dfn[u], 1, 1, 1, n), u = fr[u];
    }
    if (dfn[lca]) single_modify(dfn[lca], 0, 1, 1, n);
    if (mson[u]) single_modify(dfn[mson[u]], 0, 1, 1, n);
    if (lef[lca] <= rig[u]) modify(lef[lca], rig[u], 0, 1, 1, n);
    if (u != lca) modify(dfn[mson[lca]], dfn[u], 1, 1, 1, n);
    if (v) single_modify(dfn[v], 1, 1, 1, n);
}

void change(int u, int v) {
    int lca = LCA(u, v);
    if (u == lca) doit(v, u);
    else if (v == lca) doit(u, v);
    else {
        doit(u, lca), doit(v, lca);
        int w;
        for (w = u; dep[top[w]] > dep[lca]; w = fr[top[w]]) {
            w = top[w];
            if (fr[w] == lca) single_modify(dfn[w], 1, 1, 1, n);
        }
        if (w != lca) single_modify(dfn[mson[lca]], 1, 1, 1, n);
    }
}

int askit(int u, int lca) {
    int res = 0;
    while (dep[lca] < dep[top[u]]) {
        if (u != top[u]) res += query(dfn[mson[top[u]]], dfn[u], 1, 1, n);
        u = top[u], res += query(dfn[u], dfn[u], 1, 1, n), u = fr[u];
    }
    if (u != lca) res += query(dfn[mson[lca]], dfn[u], 1, 1, n);
    return res;
}

int ask(int u, int v) {
    int lca = LCA(u, v);
    return askit(u, lca) + askit(v, lca);
}

int main() {
    int tc = read();
    while (tc --) {
        n = read(), m = read(), dcnt = 0;
        reset(1, 1, n);
        rep(i, 1, n) e[i].clear();
        rep(i, 1, n - 1) {
            int u = read(), v = read();
            e[u].push_back(v), e[v].push_back(u);
        }
        dfs1(1, 0);
        dfs2(1, 0, 1);
        make_dfn();
        while (m --) {
            int op = read(), u = read(), v = read();
            if (op == 1) change(u, v);
            else printf("%d\n", ask(u, v));
        }
    }
    return 0;
}
posted @ 2021-07-26 19:25  Scintilla06  阅读(77)  评论(0编辑  收藏  举报