UOJ435 Simple Tree

更好的阅读体验

题意

有一棵有根树,根为 \(1\),点有点权.
现在有 \(m\) 次操作,操作有3种:

  • 1 x y w,将 \(x\)\(y\) 的路径上的点点权加上 \(w\) (其中 \(w=\pm 1\));
  • 2 x y,询问在 \(x\)\(y\) 的路径上有多少个点点权 \(>0\)
  • 3 x,询问在 \(x\) 的子树里的点有多少个点点权 \(>0\).

强制在线,\(n,m\le 10^5,-10^9\le\text{点权}\le10^9\)

题解

先考虑序列上的做法
问题比较复杂,我们用分块来处理

我们把每一块内部预先排好序,维护在排好序的块内下标最小的大于零点的下标 \(pts\) 以及对于块中每个数第一个严格大于和小于该数的位置 \(nxt\)\(lst\)

  • 查询的时候,对于整块,用右端点减去 \(pts\) 即可得到答案
    对于散块暴力统计
  • 修改的时候,对于整块,我们修改零点,然后尝试让 \(pts\)\(nxt\)\(lst\)
    对于散块,暴力修改并归并排序

把序列问题放在树上,一个显然的想法是树链剖分,\(\mathcal{O}(n\log n+q\sqrt{n}\log{n})\),不足以通过此题

其实在树链剖分的情况下 \(\sqrt{n}\) 的块长不是最优的
每次链的询问和修改,涉及的整块是 \(\mathcal{O}(\sqrt{n})\) 级别的,然而涉及的散块元素个数是 \(\mathcal{O}(\sqrt{n}\log{n})\) 级别的,因此块长 \(\sqrt{n\log n}\) 时取到最优复杂度 \(\mathcal{O}(n\log n+q\sqrt{n\log n})\)

考虑使用一个分块上树通用的trick,对每条链单独分块
下面证明链操作的时间复杂度是 \(\mathcal{O}(\sqrt{n})\)

考虑以一个点与其祖先为顶点的链,因为任意一条链都可以由两条这样的链组成
设这条链从上到下依次经过的重链分别为 \(lnk_0, lnk_1, ..., lnk_k\)

\[\text{len}(lnk_i)\le\text{subtreeSize}(\text{top}(lnk_i))\le\frac{n}{2^i} \]

其中 \(\text{top}(x)\) 表示链 \(x\) 的链顶

\[\begin{align*} T(n)&=\sum_{0\le i\le k}\sqrt{\text{len}(lnk_i)}\\ &\le\sum_{0\le i\le k}\sqrt{\frac{n}{2^i}}\\ &=\sqrt{n}\sum_{0\le i\le k}\frac{1}{\sqrt{2}}^i\\ &=\Theta(\sqrt{n}) \end{align*}\]

即复杂度为 \(\mathcal{O}(\sqrt{n})\)

下面考虑子树查询

我们把每条重链按链顶dfn序排序,发现每棵子树都是由一条重链的一部分和排序后连续的完整重链构成
修改时,对于 \(\mathcal{O}(\log n)\) 条被修改的重链,在树状数组上更新答案
子树询问时,区间查询即可

总复杂度为 \(\mathcal{O}(n\log n+q\sqrt{n}+q\log^2 n)\)

细节

散块常数大,调一调参跑得快

代码

#include <bits/stdc++.h>
using namespace std;
const int maxn = 100000, maxs = 200;
int n, q, T;
int val[maxn + 5];
vector<int> g[maxn + 5];
int stamp;
int dep[maxn + 5];
int fa[maxn + 5];
int heavy_son[maxn + 5];
int siz[maxn + 5];
int tp[maxn + 5], bt[maxn + 5];
int id[maxn + 5];
int lnk[maxn + 5];
int lcnt;
int lid[maxn + 5];
int minl[maxn + 5], maxl[maxn + 5];
void dfs1(int u, int f) {
    fa[u] = f;
    dep[u] = dep[f] + 1;
    siz[u] = 1;
    for (size_t i = 0; i < g[u].size(); i++) {
        int v = g[u][i];
        if (v == f)
            continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[heavy_son[u]])
            heavy_son[u] = v;
    }
}
void dfs2(int u, int f, int top) {
    id[u] = ++stamp;
    lnk[u] = top;
    if (u == top)
        tp[++lcnt] = id[u];
    if (heavy_son[u] == 0)
        bt[lcnt] = id[u];
    lid[u] = lcnt;
    minl[u] = lcnt + 1;
    if (heavy_son[u])
        dfs2(heavy_son[u], u, top);
    for (size_t i = 0; i < g[u].size(); i++) {
        int v = g[u][i];
        if (v == f || v == heavy_son[u])
            continue;
        dfs2(v, u, v);
    }
    maxl[u] = lcnt;
}
inline int lowbit(int x) {
    return x & (-x);
}
struct fenwic {
    int t[maxn + 5];
    void modify(int x, int y) {
        while (x <= lcnt) {
            t[x] += y;
            x += lowbit(x);
        }
    }
    int query_(int x) const {
        int res = 0;
        while (x > 0) {
            res += t[x];
            x -= lowbit(x);
        }
        return res;
    }
    int query(int l, int r) const {
        return query_(r) - query_(l - 1);
    }
} bit;
struct block {
    int cnt;
    int tot[maxn + 5];
    int a[maxn + 5];
    int blg[maxn + 5];
    int zero[maxn + 5];
    int lp[maxn + 5], rp[maxn + 5], pts[maxn + 5];
    pair<int, int> tmp1[maxs + 5], tmp2[maxs + 5];
    pair<int, int> sorted[maxn + 5];
    int nxt[maxn + 5], lst[maxn + 5], lpts[maxn + 5];
    void resize(int arr[], int l, int r, int lk) {
        int len = r - l + 1;
        int bsiz = ceil(sqrt(.12 * len));
        int bcnt = ceil(1. * len / bsiz);
        for (int i = l; i <= r; i++) {
            a[i] = arr[i];
            sorted[i] = pair<int, int>(arr[i], i);
        }
        for (int i = cnt + 1; i <= cnt + bcnt; i++) {
            lp[i] = l + (i - cnt - 1) * bsiz;
            rp[i] = min(r, l + (i - cnt) * bsiz - 1);
            for (int j = lp[i]; j <= rp[i]; j++)
                blg[j] = i;
            sort(sorted + lp[i], sorted + rp[i] + 1);
            calcNxtLst(i);
            tot[lk] += rp[i] - pts[i] + 1;
        }
        cnt += bcnt;
    }
    void calcNxtLst(int bid) {
        for (int i = rp[bid], j = rp[bid] + 1; i >= lp[bid]; i--) {
            while (i > lp[bid] && sorted[i].first == sorted[i - 1].first)
                i--;
            nxt[i] = j;
            if (j <= rp[bid])
                lst[j] = i;
            else
                lpts[bid] = i;
            j = i;
        }
        lst[lp[bid]] = lp[bid];
        pts[bid] = rp[bid] + 1;
        for (int i = lp[bid]; i <= rp[bid]; i++) {
            if (sorted[i].first > zero[bid]) {
                pts[bid] = i;
                break;
            }
        }
    }
    void mergeSort(int bid, int l, int r, int delta, int lk) {
        tot[lk] += pts[bid];
        for (int i = l; i <= r; i++)
            a[i] += delta;
        int cnt1 = 0, cnt2 = 0;
        for (int i = lp[bid]; i <= rp[bid]; i++) {
            if (sorted[i].second >= l && sorted[i].second <= r)
                tmp1[++cnt1] =
                    pair<int, int>(sorted[i].first + delta, sorted[i].second);
            else
                tmp2[++cnt2] = sorted[i];
        }
        merge(tmp1 + 1, tmp1 + cnt1 + 1, tmp2 + 1, tmp2 + cnt2 + 1,
              sorted + lp[bid]);
        calcNxtLst(bid);
        tot[lk] -= pts[bid];
    }
    void modify(int l, int r, int delta, int lk) {
        int lb = blg[l], rb = blg[r];
        if (lb == rb)
            mergeSort(lb, l, r, delta, lk);
        else {
            mergeSort(lb, l, rp[lb], delta, lk);
            mergeSort(rb, lp[rb], r, delta, lk);
            for (int i = lb + 1; i < rb; i++) {
                tot[lk] += pts[i];
                zero[i] -= delta;
                if (pts[i] > rp[i]) {
                    if (sorted[lpts[i]].first > zero[i])
                        pts[i] = lpts[i];
                } else {
                    if (delta == 1) {
                        if (sorted[lst[pts[i]]].first > zero[i])
                            pts[i] = lst[pts[i]];
                    } else if (sorted[pts[i]].first <= zero[i])
                        pts[i] = nxt[pts[i]];
                }
                tot[lk] -= pts[i];
            }
        }
    }
    int query(int l, int r) const {
        int res = 0;
        int lb = blg[l], rb = blg[r];
        if (lb == rb) {
            for (int i = l; i <= r; i++)
                if (a[i] > zero[lb])
                    res++;
        } else {
            for (int i = l; i <= rp[lb]; i++)
                if (a[i] > zero[lb])
                    res++;
            for (int i = lp[rb]; i <= r; i++)
                if (a[i] > zero[rb])
                    res++;
            for (int i = lb + 1; i < rb; i++)
                res += rp[i] - pts[i] + 1;
        }
        return res;
    }
} bl;
int queryLnk(int x, int y) {
    int res = 0;
    while (lnk[x] != lnk[y]) {
        if (dep[lnk[x]] < dep[lnk[y]])
            swap(x, y);
        res += bl.query(tp[lid[x]], id[x]);
        x = fa[lnk[x]];
    }
    if (dep[y] < dep[x])
        swap(x, y);
    res += bl.query(id[x], id[y]);
    return res;
}
int querySubtree(int x) {
    return bl.query(id[x], bt[lid[x]]) + bit.query(minl[x], maxl[x]);
}
void modifyLnk(int x, int y, int delta) {
    while (lnk[x] != lnk[y]) {
        if (dep[lnk[x]] < dep[lnk[y]])
            swap(x, y);
        int ori = bl.tot[lid[x]];
        bl.modify(tp[lid[x]], id[x], delta, lid[x]);
        bit.modify(lid[x], bl.tot[lid[x]] - ori);
        x = fa[lnk[x]];
    }
    if (dep[y] < dep[x])
        swap(x, y);
    int ori = bl.tot[lid[x]];
    bl.modify(id[x], id[y], delta, lid[x]);
    bit.modify(lid[x], bl.tot[lid[x]] - ori);
}
int main() {
    scanf("%d%d%d", &n, &q, &T);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1, 0);
    dfs2(1, 0, 1);
    for (int i = 1; i <= n; i++)
        scanf("%d", &val[id[i]]);
    for (int i = 1; i <= lcnt; i++) {
        bl.resize(val, tp[i], bt[i], i);
        bit.modify(i, bl.tot[i]);
    }
    int lstans = 0;
    for (int i = 1; i <= q; i++) {
        int op, x, y, w;
        scanf("%d%d", &op, &x);
        if (T == 1)
            x ^= lstans;
        if (op == 1) {
            scanf("%d%d", &y, &w);
            if (T == 1)
                y ^= lstans;
            modifyLnk(x, y, w);
        } else if (op == 2) {
            scanf("%d", &y);
            if (T == 1)
                y ^= lstans;
            printf("%d\n", lstans = queryLnk(x, y));
        } else
            printf("%d\n", lstans = querySubtree(x));
    }
}
posted @ 2022-02-01 11:11  gzezFISHER  阅读(31)  评论(0编辑  收藏  举报