解题报告 『[国家集训队]Tree II(LCT)』

原题地址：洛谷 P1501（[国家集训队] Tree II）

LCT 裸题，注意加法标记与乘法标记的下传顺序即可：一个节点的值按 v → v·mul + add 变换，因此下传时必须先乘后加——子节点原有的加法标记也要先乘以父节点的 mul，再加上父节点的 add。所有运算都在模 51061 下进行。

 

代码实现如下:

#include <bits/stdc++.h>
// `using namespace std;` is intentionally NOT used: under C++17 it makes every
// use of the global array `size` below ambiguous with std::size().
// All `int`s are widened to 64 bits: tag products such as 51060 * 51060 do not
// fit in 32 bits.
#define int long long
// `register` was dropped from the loop variable: it is ill-formed since C++17.
#define rep(i, a, b) for (int i = (a); i <= (b); i++)

const int mod = 51061, maxn = 1e5 + 5;

int n, m;
// Per-node link-cut-tree state (node 0 is a sentinel reset by clear(0)):
//   fa   - splay parent, or path-parent when the node is a splay root
//   ch   - left/right splay children (only indices 0 and 1 are used)
//   val  - weight of the node itself;  sum/size - aggregates over its splay subtree
//   add/mul - pending affine tags for the children;  rev - pending child swap
int fa[maxn], add[maxn], mul[maxn], rev[maxn], sum[maxn], val[maxn], size[maxn], ch[maxn][2];

// Which side of its splay parent x hangs on: 1 for the right child, 0 otherwise.
int get(int x) {
    return ch[fa[x]][1] == x;
}

// Initialise every multiplication tag to the identity value 1.
// The previous memset(mul, 1, sizeof(mul)) wrote the byte 0x01 into every byte,
// so each 64-bit entry became 0x0101010101010101 rather than 1.
void origin() {
    std::fill(mul, mul + sizeof(mul) / sizeof(mul[0]), 1);
}

// True when x is the root of its splay tree, i.e. its parent does not list it
// as a child (fa[x] is then a path-parent pointer, or 0).
int is_root(int x) {
    const int parent = fa[x];
    return ch[parent][0] != x && ch[parent][1] != x;
}
  
// Read one optionally negative decimal integer from stdin.
// Characters before the first '-' or digit are skipped. A '-' with no digits
// after it yields 0. If EOF is hit before any digit, 0 is returned. The
// previous version kept getchar()'s result in a char and never tested for EOF,
// so a call past end-of-input looped forever. Its local `ch` also shadowed the
// global child array.
int read() {
    int c = getchar();  // keep getchar()'s int result so EOF stays distinguishable
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    bool negative = false;
    if (c == '-') {
        negative = true;
        c = getchar();
    }
    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -x : x;
}

// Reset node x to an empty, tag-free state. The rest of the file calls this on
// the sentinel node 0 so that stray writes to it are discarded.
void clear(int x) {
    ch[x][0] = ch[x][1] = 0;
    fa[x] = 0;
    add[x] = 0;
    rev[x] = 0;
    sum[x] = 0;
    val[x] = 0;
    size[x] = 0;
    mul[x] = 1;  // multiplicative identity, not 0
}
  
// Recompute x's subtree aggregates (node count and weight sum, both mod `mod`)
// from its two children. Node 0 is cleared first so a missing child adds 0.
void maintain(int x) {
    clear(0);
    const int left = ch[x][0];
    const int right = ch[x][1];
    size[x] = (1 + size[left] + size[right]) % mod;
    sum[x] = (val[x] + sum[left] + sum[right]) % mod;
}
  
// Scale the whole subtree of y by k. The node's value, its subtree sum and both
// of its own pending tags are multiplied, because a later add tag must apply
// after this multiplication.
static void push_down_apply_mul(int y, int k) {
    mul[y] = mul[y] * k % mod;
    val[y] = val[y] * k % mod;
    sum[y] = sum[y] * k % mod;
    add[y] = add[y] * k % mod;
}

// Add k to every node in y's subtree. The subtree sum grows by k once per node.
static void push_down_apply_add(int y, int k) {
    add[y] = (add[y] + k) % mod;
    val[y] = (val[y] + k) % mod;
    sum[y] = (sum[y] + k * size[y] % mod) % mod;
}

// Push x's pending tags one level down to its children.
// Convention in this file: rev[x] set means x's own two children have NOT been
// swapped yet. Multiplication is pushed before addition, because a node's value
// is transformed as v -> v * mul + add.
void push_down(int x) {
    clear(0);  // discard anything previously written to the sentinel
    if (rev[x]) {
        // std::swap replaces `a ^= b ^= a ^= b`, which is undefined behaviour
        // before C++17 (two unsequenced modifications of the same object).
        std::swap(ch[x][0], ch[x][1]);
        rev[ch[x][0]] ^= 1;
        rev[ch[x][1]] ^= 1;
        rev[x] = 0;
    }
    if (mul[x] != 1) {
        for (int side = 0; side < 2; ++side)
            if (ch[x][side]) push_down_apply_mul(ch[x][side], mul[x]);
        mul[x] = 1;
    }
    if (add[x]) {
        for (int side = 0; side < 2; ++side)
            if (ch[x][side]) push_down_apply_add(ch[x][side], add[x]);
        add[x] = 0;
    }
}
  
// Push pending tags on every node from the root of x's splay tree down to x,
// top-most first. An explicit list replaces the recursion, so a deep splay
// tree cannot exhaust the call stack; the push order is unchanged.
void update(int x) {
    static std::vector<int> chain;
    chain.clear();
    chain.push_back(x);
    for (int y = x; !is_root(y);) {
        y = fa[y];
        chain.push_back(y);
    }
    for (std::size_t i = chain.size(); i-- > 0;) push_down(chain[i]);
}

// Rotate x above its splay parent y. Callers must have pushed tags on x and y
// beforehand (splay does this via update).
// Only y and x need maintain(). The grandparent z keeps exactly the same subtree
// node set, so its aggregates are unchanged. The old maintain(z) was at best
// redundant. When y was a splay root, z is a path-parent (or 0) whose children
// may still carry unpushed tags, so recomputing z from them temporarily
// corrupted sum[z] and left size[0] == 1.
void rotate(int x) {
    int y = fa[x], z = fa[y], chk = get(x);
    int moved = ch[x][chk ^ 1];             // subtree that changes parent from x to y
    if (!is_root(y)) ch[z][get(y)] = x;     // must be tested before links change
    ch[y][chk] = moved;
    if (moved) fa[moved] = y;               // don't write into the sentinel node 0
    ch[x][chk ^ 1] = y;
    fa[y] = x;
    fa[x] = z;
    maintain(y);
    maintain(x);
}
  
// Splay x to the root of its splay tree, applying the zig-zig / zig-zag rules.
// The `register` specifier on the loop variable was removed: it is ill-formed
// since C++17.
void splay(int x) {
    update(x);  // make every tag on the root-to-x path visible first
    while (!is_root(x)) {
        int f = fa[x];
        if (!is_root(f)) rotate(get(x) == get(f) ? f : x);
        rotate(x);
    }
}

// Make the path from the represented root down to x one preferred path.
// Each splay tree met on the way up gets the previously processed tree as its
// right child (the right child is 0 on the first iteration).
// `register` was removed (ill-formed since C++17).
void access(int x) {
    for (int below = 0; x; below = x, x = fa[x]) {
        splay(x);
        ch[x][1] = below;
        maintain(x);
    }
}

// Re-root x's represented tree at x. After access and splay, x is the splay
// root of the root-to-x path. Reversing that path puts x first. The reversal is
// applied to x immediately, and the tag is passed to its children.
void make_root(int x) {
    access(x);
    splay(x);
    rev[x] ^= 1;
    if (rev[x]) {
        // std::swap replaces `a ^= b ^= a ^= b` (undefined behaviour before C++17).
        std::swap(ch[x][0], ch[x][1]);
        rev[ch[x][0]] ^= 1;
        rev[ch[x][1]] ^= 1;
        rev[x] = 0;
    }
}

// Return the root of x's represented tree, i.e. the left-most node on the
// root-to-x path.
// Each node is pushed BEFORE its left child is read. Under this file's
// convention, a pending reversal means ch[] has not been swapped yet, so the old
// loop could read an unswapped ch[x][0] and stop at the wrong node.
int find(int x) {
    access(x);
    splay(x);
    push_down(x);
    while (ch[x][0]) {
        x = ch[x][0];
        push_down(x);
    }
    splay(x);  // splay the found root (standard LCT practice for amortised cost)
    return x;
}

// Expose the u..v path as a single splay tree rooted at v, so that sum[v] and
// size[v] describe exactly the nodes on that path.
void split(int u, int v) {
    make_root(u);  // u becomes the first node of every root-to-node path
    access(v);     // the preferred path now runs from u to v
    splay(v);      // lift v to the top of that splay tree
}

// Add the edge u-v unless u and v are already in the same tree.
// After make_root(u), find(v) == u means v already shares u's tree, and linking
// would form a cycle.
void link(int u, int v) {
    make_root(u);
    if (find(v) == u) return;
    fa[u] = v;  // u's whole tree hangs below v via a path-parent pointer
}

// Remove the edge u-v if it exists; otherwise leave the forest unchanged.
// After split(u, v), v is the splay root of the u..v path. The edge exists
// exactly when that path holds only u and v: u is v's left child and u has no
// children at all. Both of u's child slots are tested, not just ch[u][1].
// A reversal tag on u that has not been pushed yet only swaps those two slots,
// so the test stays correct without pushing u first.
void cut(int u, int v) {
    split(u, v);
    if (ch[v][0] != u || ch[u][0] || ch[u][1]) return;
    fa[u] = ch[v][0] = 0;
}

// Print x in decimal to stdout (no trailing newline). Digits are collected
// least-significant first into a buffer and emitted in reverse.
void write(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    char digits[24];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + x % 10);
        x /= 10;
    } while (x > 0);
    while (len > 0) putchar(digits[--len]);
}

signed main() {
    origin();
    n = read(), m = read();
    rep(i, 1, n) val[i] = 1;
    rep(i, 1, n - 1) {
        int u, v;
        u = read(), v = read();
        link(u, v);
    }
    rep(i, 1, m) {
        int u, v;
        char opt[2];
        scanf("%s", opt), u = read(), v = read();
        switch(opt[0]) {
            case('+'): {
                int c;
                c = read();
                split(u, v);
                add[v] = (add[v] + c) % mod;
                val[v] = (val[v] + c) % mod;
                sum[v] = (sum[v] + size[v] * c) % mod;
                break;
            }
            case('-'): {
                int a, b;
                a = read(), b = read();
                cut(u, v);
                link(a, b);
                break;
            }
            case('*'): {
                int c;
                c = read();
                split(u, v);
                mul[v] = (mul[v] * c) % mod;
                val[v] = (val[v] * c) % mod;
                sum[v] = (sum[v] * c) % mod;
                break;
            }
            case('/'): {
                split(u, v);
                write(sum[v]);
                printf("\n");
                break;
            }
        }
    }
    return 0;
}
View Code
posted @ 2019-08-23 15:47  雲裏霧裏沙  阅读(162)  评论(0编辑  收藏  举报