[ZJOI2019]语言——树剖+树上差分+线段树合并

原题链接戳这儿

SOLUTION

考虑一种非常\(naive\)的统计方法,就是对于每一个点\(u\),我们维护它能到达的点集\(S_u\),最后答案就是\(\frac{\sum\limits_{i=1}^{n}|S_i|}{2}\)
也就是说我们可以先树剖一下,对于每一个点都开一棵线段树,每次修改\(O(nlogn)\)地更新一下路径上的线段树,最后查询一下就行了
但是这样的复杂度是\(O(n^2log^2n)\)的,显然会炸。注意到每次是对一条链上的所有点操作,所以我们可以查分。又因为差分之后要把子树的贡献传上去,再上个线段树合并就行了,复杂度降为\(O(nlog^2n)\)
代码也比较好写,细节不多

#include <bits/stdc++.h>

using namespace std;

#define N 100000
#define ll long long
#define mp make_pair
#define pii pair<int, int>
#define pb push_back
#define mid ((l + r) >> 1)

int n, m;
vector<int> G[N + 5];
int sz[N + 5], fa[N + 5], d[N + 5], hson[N + 5], top[N + 5], dfn[N + 5], dfn_clk, id[N + 5];
int nid, root[N + 5], sumv[N << 7], ch[2][N << 7], addv[N << 7];
vector<pii> cf[N + 5], segs[N + 5];
ll ans;

void dfs1(int u, int pa) {
    sz[u] = 1;
    fa[u] = pa;
    for (int i = 0, v; i < G[u].size(); ++i) {
        v = G[u][i];
        if (v == pa) continue;
        d[v] = d[u] + 1;
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[hson[u]]) hson[u] = v;
    }
}

void dfs2(int u, int tp) {
    dfn[u] = ++dfn_clk;
    id[dfn_clk] = u;
    top[u] = tp;
    if (hson[u]) dfs2(hson[u], tp);
    for (int i = 0, v; i < G[u].size(); ++i) {
        v = G[u][i];
        if (v == fa[u] || v == hson[u]) continue;
        dfs2(v, v);
    }
}

int lca(int x, int y) {
    while (top[x] != top[y]) d[top[x]] > d[top[y]] ? x = fa[top[x]] : y = fa[top[y]];
    return d[x] < d[y] ? x : y;
}

void addModify(int x, int s, int t) {
    int z = lca(s, t);
    cf[s].pb(mp(x, 1)), cf[t].pb(mp(x, 1));
    cf[z].pb(mp(x, -1)), cf[fa[z]].pb(mp(x, -1));
    while (top[s] != top[z]) segs[x].pb(mp(dfn[top[s]], dfn[s])), s = fa[top[s]];
    segs[x].pb(mp(dfn[z], dfn[s]));
    while (top[t] != top[z]) segs[x].pb(mp(dfn[top[t]], dfn[t])), t = fa[top[t]];
    if (dfn[z] + 1 <= dfn[t]) segs[x].pb(mp(dfn[z] + 1, dfn[t]));
}

void pushup(int o, int l, int r) {
    if (addv[o]) sumv[o] = r - l + 1;
    else sumv[o] = sumv[ch[0][o]] + sumv[ch[1][o]];
}

void add(int &o, int l, int r, int L, int R, int k) {
    if (!o) o = ++nid;
    if (L <= l && r <= R) {
        addv[o] += k;
        pushup(o, l, r);
        return ;
    }
    if (L <= mid) add(ch[0][o], l, mid, L, R, k);
    if (R > mid) add(ch[1][o], mid + 1, r, L, R, k);
    pushup(o, l, r);
}

void merge(int &o, int u, int l, int r) {
    if (!o || !u) {
        if (!o) o = u;
        return ;
    }
    addv[o] += addv[u];
    if (l < r) {
        merge(ch[0][o], ch[0][u], l, mid);
        merge(ch[1][o], ch[1][u], mid + 1, r);
    }
    pushup(o, l, r);
}

void dfs(int u) {
    for (int i = 0, v; i < G[u].size(); ++i) {
        v = G[u][i];
        if (v == fa[u]) continue;
        dfs(v);
        merge(root[u], root[v], 1, n);
    }
    for (int i = 0; i < cf[u].size(); ++i)
        for (int j = 0; j < segs[cf[u][i].first].size(); ++j)
            add(root[u], 1, n, segs[cf[u][i].first][j].first, segs[cf[u][i].first][j].second, cf[u][i].second);
    ans += max(0, sumv[root[u]] - 1); // 注意这里要与0取max
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1, x, y; i < n; ++i) {
        scanf("%d%d", &x, &y);
        G[x].pb(y), G[y].pb(x);
    }
    dfs1(1, 0), dfs2(1, 1);
    for (int i = 1, s, t; i <= m; ++i) {
        scanf("%d%d", &s, &t);
        addModify(i, s, t);
    }
    dfs(1);
    ans /= 2;
    printf("%lld\n", ans);
    return 0;
}
posted @ 2019-07-04 12:42  dummyummy  阅读(247)  评论(0编辑  收藏  举报