LOJ2434. 「ZJOI2018」历史

题目


正解

比第一题清真多了……但还是很难想……

首先读完题目之后是个人都会想到\(LCT\)。题目的那个过程就是模拟access的过程。
然而如果仅仅从这个模型上考虑不会有什么思路。

把每个点分开考虑,分析一个国家\(x\)会怎样被其它国家(记为\(y\))打败:

  1. \(y\)\(x\)的子树内。此时\(y\)一定会占领\(x\)的全部领土。将这个贡献记在\(y\)上。
  2. \(y\)\(x\)的祖先,或者是祖先的另一棵子树内。那么此时\(y\)会占领\(x\)\(lca\)及以上的领土。将这个贡献记在\(lca\)上。

\(siz_x\)表示\(x\)的子树内所有节点的操作次数之和。
挂在某个节点\(x\)上的答案是什么?可以发现对于\(x\)的一个儿子\(y\)\(y\)的子树内操作无论如何排列都是等价的。于是我们有若干种操作(所有\(y\)内子树的操作和\(x\)的操作),将这些操作进行排列,最终使得相邻不同种的个数最大。
玩很久可以发现贡献为\(\min(\sum s_i-1,2(\sum s_i-s_{max}))\)
\(s_i\)表示第\(i\)种操作的个数,\(s_{max}=\max s_i\)

于是这样30分就有了,现在考虑如何动态维护这个东西。
观察到右边小于等于左边的条件,\(2s_{max}\leq \sum s_i+1\)
于是对\(x\)的一个儿子\(y\),如果\(2s_y\leq \sum s_i+1\),就记\(y\)为重儿子,否则为轻儿子。
可以发现重儿子最多会有一个,也可能没有重儿子。
我们要维护数据结构,可以对一条链上的\(s_x\)进行区间加,并且能维护轻重儿子关系。如果本身就是重儿子,子树加之后一定还是重儿子。于是我们只需要考虑当前修改的点到根路径上有哪些轻儿子,逐个对这些轻儿子进行修改。
用树链剖分的方法分析,这条链上的轻边的个数是\(O(\log \sum a_i)\)级别的。

这个数据结构用\(LCT\)就可以非常简单地实现。树剖也可以。


代码

using namespace std;
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cassert>
#define N 400010
#define ll long long
int n, m;
ll a[N];
struct EDGE {
    int to;
    EDGE *las;
} e[N * 2];
int ne;
EDGE *last[N];
void link(int u, int v) {
    e[ne] = { v, last[u] };
    last[u] = e + ne++;
}
int fa[N];
ll ans;
ll siz[N];
int hs[N];
void init(int x) {
    siz[x] = a[x];
    ll mx = a[x];
    for (EDGE *ei = last[x]; ei; ei = ei->las)
        if (ei->to != fa[x]) {
            fa[ei->to] = x;
            init(ei->to);
            siz[x] += siz[ei->to];
            mx = max(mx, siz[ei->to]);
        }
    ans += min(siz[x] - 1, 2 * (siz[x] - mx));
}
struct Node *null;
struct Node {
    Node *fa, *c[2];
    bool isr;
    ll tag, s;
    bool pf, ha;
    ll ans;
    void gettag(ll c) { tag += c, s += c; }
    void pd() {
        if (tag) {
            c[0]->gettag(tag);
            c[1]->gettag(tag);
            tag = 0;
        }
    }
    void push() {
        if (!isr)
            fa->push();
        pd();
    }
    void upd() { ha = (c[0]->ha && c[1]->ha && pf); }
    bool getson() { return fa->c[0] != this; }
    void rotate() {
        Node *y = fa, *z = y->fa;
        if (y->isr)
            y->isr = 0, isr = 1;
        else
            z->c[y->getson()] = this;
        int k = getson();
        fa = z;
        y->c[k] = c[!k];
        c[!k]->fa = y;
        c[!k] = y;
        y->fa = this;
        ha = y->ha, y->upd();
    }
    void splay() {
        push();
        while (!isr) {
            if (!fa->isr)
                fa->getson() != getson() ? rotate() : fa->rotate();
            rotate();
        }
    }
    Node *access() {
        Node *x = this, *y = null;
        for (; x != null; y = x, x = x->fa) {
            x->splay();
            x->c[1]->isr = 1;
            x->c[1] = y;
            y->isr = 0;
            x->upd();
        }
        return y;
    }
} d[N];
void dfs(Node *t) {
    if (t->ha == 1)
        return;
    t->pd();
    if (t->pf == 0) {
        int x = t - d;
        ans -= d[fa[x]].ans;
        d[fa[x]].push();  // can be faster
        ll sum = d[fa[x]].s;
        int y = hs[fa[x]];
        if (y)
            d[y].splay();
        if (2 * t->s >= sum + 1) {
            if (y)
                d[y].ha = d[y].pf = 0;
            t->pf = 1, hs[fa[x]] = x;
            d[fa[x]].ans = 2 * (sum - t->s);
        } else if (y && 2 * d[y].s >= sum + 1)
            d[fa[x]].ans = 2 * (sum - d[y].s);
        else {
            if (y)
                d[y].ha = d[y].pf = 0, hs[fa[x]] = 0;
            d[fa[x]].ans = min(sum - 1, 2 * (sum - a[fa[x]]));
        }
        ans += d[fa[x]].ans;
    }
    dfs(t->c[0]);
    dfs(t->c[1]);
    t->upd();
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        link(u, v), link(v, u);
    }
    init(1);
    null = d;
    *null = { null, null, null, 0, 0, 0, 0, 1, 0 };
    d[1] = { null, null, null, 1, 0, siz[1], 1, 1 };
    for (int i = 2; i <= n; ++i) {
        d[i] = { &d[fa[i]], null, null, 1, 0, siz[i] };
        if (2 * siz[i] >= siz[fa[i]] + 1)
            d[i].ha = d[i].pf = 1, hs[fa[i]] = i;
    }
    for (int i = 1; i <= n; ++i)
        if (hs[i])
            d[i].ans = 2 * (siz[i] - siz[hs[i]]);
        else
            d[i].ans = min(siz[i] - 1, 2 * (siz[i] - a[i]));
    printf("%lld\n", ans);
    while (m--) {
        int x, y;
        scanf("%d%d", &x, &y);
        a[x] += y;
        d[x].access(), d[x].splay();
        d[x].gettag(y);
        dfs(&d[x]);
        ans -= d[x].ans;
        if (hs[x]) {
            int z = hs[x];
            d[z].splay();
            if (d[z].s * 2 >= d[x].s + 1)
                d[x].ans = 2 * (d[x].s - d[z].s);
            else {
                d[z].ha = d[z].pf = 0, hs[x] = 0;
                d[x].ans = min(d[x].s - 1, 2 * (d[x].s - a[x]));
            }
        } else
            d[x].ans = min(d[x].s - 1, 2 * (d[x].s - a[x]));
        ans += d[x].ans;
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2020-07-08 17:17  jz_597  阅读(97)  评论(0编辑  收藏  举报