[luoguP4074/WC2013] 糖果公园

题意

给你一棵树,每个点有个颜色
每次询问你一条路径求 $$\sum_{c}val_c\sum_{i=1}^{cnt_c}worth_i$$其中\(val\)表示该颜色的价值,\(cnt\)表示其出现的次数,\(worth_i\) 表示第 \(i\) 次出现的价值,带修改。

sol

显然莫队,但是是一道树上莫队和带修莫队的结合题,考验码力。
树上莫队:[luoguSP10707] Count on a tree II
带修莫队:[luoguP1903] 数颜色
这里注意,树上带修时,强制修改颜色在交换前后各计算一次即可,无需(也无法)直接传入颜色,且仅当待修改点被标记已被加入答案才被修改(而不是普通带修莫队的判断区间)

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>

#define x first 
#define y second 

using namespace std;
typedef pair<int, int> PII;
typedef long long LL;

const int N = 100005, M = 2 * N;

int h[N], e[M], ne[M], idx;
int c[N], buc[N];
int v[N], w[N];
int block[2 * N], eular[2 * N], rk1[N], rk2[N], timestamp;
bool st[N];
LL res = 0, ans[N];
int fa[N][18], dep[N];
PII edit[N];
int n, m, q;

struct Query {
    int l, r, time, lca, id;
    bool operator< (const Query &W) const {
        if (block[l] != block[W.l]) return block[l] < block[W.l];
        if (block[r] != block[W.r]) return block[l] & 1 ? block[r] < block[W.r] : block[r] > block[W.r];
        return block[r] & 1 ? time < W.time : time > W.time;
    }
} queries[N];

void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

void dfs_init(int u, int father){
    dep[u] = dep[father] + 1, fa[u][0] = father;
    eular[ ++ timestamp] = u, rk1[u] = timestamp;

    for (int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if (j == father) continue;
        dfs_init(j, u);
    }

    eular[ ++ timestamp] = u, rk2[u] = timestamp;
}

int lca(int x, int y){
    if (dep[x] < dep[y]) swap(x, y);

    for (int i = 17; i >= 0; i -- ){
        int fax = fa[x][i];
        if (dep[fax] >= dep[y]) x = fax;
    }

    if (x == y) return x;

    for (int i = 17; i >= 0; i -- ){
        int fax = fa[x][i], fay = fa[y][i];
        if (fax != fay) x = fax, y = fay;
    }

    return fa[x][0];
}

void calc(int x){
    st[x] = !st[x];
    if (st[x]) {
        buc[c[x]] ++ ;
        res += (LL) v[c[x]] * w[buc[c[x]]];
    }
    else {
        res -= (LL) v[c[x]] * w[buc[c[x]]];
        buc[c[x]] -- ;
    }
}

int main(){
    memset(h, -1, sizeof h);

    scanf("%d%d%d", &n, &m, &q);
    for (int i = 1; i <= m; i ++ ) scanf("%d", &v[i]);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
    for (int i = 1; i < n; i ++ ){
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b), add(b, a);
    }

    dfs_init(1, 0);
    for (int k = 1; k < 18; k ++ )
        for (int i = 1; i <= n; i ++ )
            fa[i][k] = fa[fa[i][k - 1]][k - 1];
    
    for (int i = 1; i <= n; i ++ ) scanf("%d", &c[i]);
    int qcnt = 0, ecnt = 0;
    while (q -- ){
        int op, x, y;
        scanf("%d%d%d", &op, &x, &y);
        if (op == 0) edit[ ++ ecnt] = {x, y};
        else {
            qcnt ++ ;
            if (rk1[x] > rk1[y]) swap(x, y);
            int l = lca(x, y);
            if (l == x) queries[qcnt] = {rk1[x], rk1[y], ecnt, 0, qcnt};
            else queries[qcnt] = {rk2[x], rk1[y], ecnt, l, qcnt};
        }
    }

    int Sz = pow(2 * n, 0.667);
    int bcnt = ceil(2.0 * n / Sz);
    for (int i = 1; i <= bcnt; i ++ )
        for (int j = (i - 1) * Sz + 1; j <= min(2 * n, i * Sz); j ++ )
            block[j] = i;
    
    sort(queries + 1, queries + qcnt + 1);

    int l = 1, r = 0, time = 0;
    for (int i = 1; i <= qcnt; i ++ ){
        Query &q = queries[i];
        while (l > q.l) calc(eular[ -- l]);
        while (r < q.r) calc(eular[ ++ r]);
        while (l < q.l) calc(eular[l ++ ]);
        while (r > q.r) calc(eular[r -- ]);
        while (time < q.time) {
            time ++ ;
            if (st[edit[time].x]) {
                calc(edit[time].x);
                swap(edit[time].y, c[edit[time].x]);
                calc(edit[time].x);
            }
            else swap(edit[time].y, c[edit[time].x]);
        }
        while (time > q.time) {
            if (st[edit[time].x]) {
                calc(edit[time].x);
                swap(edit[time].y, c[edit[time].x]);
                calc(edit[time].x);
            }
            else swap(edit[time].y, c[edit[time].x]);
            time -- ;
        }
        if (q.lca) calc(q.lca);
        ans[q.id] = res;
        if (q.lca) calc(q.lca);
    }

    for (int i = 1; i <= qcnt; i ++ ) printf("%lld\n", ans[i]);

    return 0;
}
posted @ 2024-11-28 21:55  是一只小蒟蒻呀  阅读(3)  评论(0编辑  收藏  举报