解题报告 『[国家集训队]Tree II(LCT)』
LCT 裸题，注意一下加法标记和乘法标记的下传顺序即可：每个节点留给子树的懒标记表示变换 x → x·mul + add，因此下传时要先下传乘法标记、再下传加法标记；给一个节点打上乘法标记 c 时，它已有的加法标记也要同时乘以 c。
代码实现如下：
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>

// [国家集训队] Tree II: node values start at 1; operations (all mod 51061):
//   + u v c        add c to every node on the path u..v
//   - u1 v1 u2 v2  cut edge (u1, v1), then link edge (u2, v2)
//   * u v c        multiply every node on the path u..v by c
//   / u v          print the sum of the values on the path u..v
// The forest is kept as a link-cut tree: one splay tree per preferred path,
// in-order = increasing depth.
//
// Changes relative to the original paste:
//  * the multiplicative identity is set with std::fill; memset(mul, 1, ...)
//    wrote the byte 0x01 into every byte of each long long;
//  * the array formerly named `size` (ambiguous with std::size in C++17 under
//    `using namespace std;`) is now `siz`, and `register` (ill-formed in
//    C++17) is gone;
//  * find_root splays the found root (keeps splay amortization);
//  * cut re-maintains v after detaching u;
//  * '*' goes through apply_mul, which also scales the pending add tag;
//  * the reader stops at EOF instead of spinning forever.
namespace {

using ll = long long;

constexpr ll kMod = 51061;
constexpr int kMaxN = 100000 + 5;

// Node 0 is the "null" sentinel. None of the helpers below write to it, so it
// keeps siz 0 / sum 0 and needs no per-call reset.
int fa[kMaxN];     // splay parent; for a splay root, its path-parent (0 = none)
int ch[kMaxN][2];  // splay children: [0] shallower side, [1] deeper side
bool rev[kMaxN];   // x's own children are already swapped; its subtrees are not
ll val[kMaxN];     // value stored at the node itself (already up to date)
ll sum[kMaxN];     // sum of val over x's splay subtree, mod kMod
ll siz[kMaxN];     // number of nodes in x's splay subtree
ll mul[kMaxN];     // pending multiplier for x's children
ll add[kMaxN];     // pending addend for x's children (applied after mul)

// True when x is the root of its splay tree (its fa link is a path-parent).
bool is_root(int x) { return ch[fa[x]][0] != x && ch[fa[x]][1] != x; }

// Which child slot of its splay parent x occupies (0 = left, 1 = right).
int side(int x) { return ch[fa[x]][1] == x ? 1 : 0; }

// Recompute x's aggregates from its children.
void maintain(int x) {
  const int l = ch[x][0];
  const int r = ch[x][1];
  siz[x] = siz[l] + siz[r] + 1;
  sum[x] = (sum[l] + sum[r] + val[x]) % kMod;
}

// Reverse x's splay subtree: swap x's children now, defer the rest.
void apply_rev(int x) {
  if (!x) return;
  std::swap(ch[x][0], ch[x][1]);
  rev[x] = !rev[x];
}

// Multiply every value in x's splay subtree by c (0 <= c < kMod).
// The pending child transform v -> v*mul + add becomes v*mul*c + add*c,
// so both tags are scaled.
void apply_mul(int x, ll c) {
  if (!x) return;
  val[x] = val[x] * c % kMod;
  sum[x] = sum[x] * c % kMod;
  mul[x] = mul[x] * c % kMod;
  add[x] = add[x] * c % kMod;
}

// Add c to every value in x's splay subtree (0 <= c < kMod).
void apply_add(int x, ll c) {
  if (!x) return;
  val[x] = (val[x] + c) % kMod;
  sum[x] = (sum[x] + c * siz[x]) % kMod;
  add[x] = (add[x] + c) % kMod;
}

// Hand x's pending tags to its children. The multiplier must go before the
// addend, because the pending transform is v -> v*mul + add.
void push_down(int x) {
  const int l = ch[x][0];
  const int r = ch[x][1];
  if (rev[x]) {
    apply_rev(l);
    apply_rev(r);
    rev[x] = false;
  }
  if (mul[x] != 1) {
    apply_mul(l, mul[x]);
    apply_mul(r, mul[x]);
    mul[x] = 1;
  }
  if (add[x] != 0) {
    apply_add(l, add[x]);
    apply_add(r, add[x]);
    add[x] = 0;
  }
}

// Push tags down along the splay path from x's splay root to x, top first.
// Iterative so that a long splay path cannot overflow the call stack.
void push_path(int x) {
  static int stk[kMaxN];
  int top = 0;
  stk[top++] = x;
  for (int y = x; !is_root(y); y = fa[y]) stk[top++] = fa[y];
  while (top > 0) push_down(stk[--top]);
}

// Single splay rotation of x over its parent.
void rotate(int x) {
  const int y = fa[x];
  const int z = fa[y];
  const int k = side(x);
  if (!is_root(y)) ch[z][side(y)] = x;  // otherwise z is only a path-parent
  const int w = ch[x][k ^ 1];
  ch[y][k] = w;
  if (w) fa[w] = y;  // never write into the sentinel
  ch[x][k ^ 1] = y;
  fa[y] = x;
  fa[x] = z;
  // z's subtree content is unchanged, so only y and x need recomputing.
  maintain(y);
  maintain(x);
}

// Bring x to the root of its splay tree.
void splay(int x) {
  push_path(x);
  while (!is_root(x)) {
    const int y = fa[x];
    if (!is_root(y)) rotate(side(x) == side(y) ? y : x);
    rotate(x);
  }
}

// Make the tree-root..x path preferred, with x as its deepest node.
void access(int x) {
  for (int p = 0; x; p = x, x = fa[x]) {
    splay(x);
    ch[x][1] = p;
    maintain(x);
  }
}

// Re-root x's tree at x.
void make_root(int x) {
  access(x);
  splay(x);
  apply_rev(x);
}

// Root of x's tree. The found root is splayed to keep the amortized bound.
int find_root(int x) {
  access(x);
  splay(x);
  while (ch[x][0]) {
    push_down(x);
    x = ch[x][0];
  }
  splay(x);
  return x;
}

// Afterwards v is a splay root whose splay tree is exactly the path u..v.
void split(int u, int v) {
  make_root(u);
  access(v);
  splay(v);
}

// Add edge (u, v) unless u and v are already connected.
void link_edge(int u, int v) {
  make_root(u);
  if (find_root(v) != u) fa[u] = v;
}

// Remove edge (u, v) if it exists.
void cut_edge(int u, int v) {
  split(u, v);
  if (ch[v][0] == u && !ch[u][1]) {
    ch[v][0] = 0;
    fa[u] = 0;
    maintain(v);  // v no longer contains u's contribution
  }
}

// Read an optionally negative decimal integer; returns 0 at EOF.
ll read_int() {
  int c = std::getchar();
  while (c != EOF && c != '-' && !std::isdigit(c)) c = std::getchar();
  bool negative = false;
  if (c == '-') {
    negative = true;
    c = std::getchar();
  }
  ll x = 0;
  while (c != EOF && std::isdigit(c)) {
    x = x * 10 + (c - '0');
    c = std::getchar();
  }
  return negative ? -x : x;
}

// Read the next non-whitespace character (the operation symbol), or EOF.
int read_op() {
  int c = std::getchar();
  while (c != EOF && std::isspace(c)) c = std::getchar();
  return c;
}

// Map any integer to its residue in [0, kMod).
ll to_residue(ll c) {
  c %= kMod;
  return c < 0 ? c + kMod : c;
}

}  // namespace

int main() {
  const int n = static_cast<int>(read_int());
  const int m = static_cast<int>(read_int());
  std::fill(mul, mul + kMaxN, 1LL);  // multiplicative identity for every node
  for (int i = 1; i <= n; ++i) {
    val[i] = 1;
    sum[i] = 1;
    siz[i] = 1;
  }
  for (int i = 1; i < n; ++i) {
    const int u = static_cast<int>(read_int());
    const int v = static_cast<int>(read_int());
    link_edge(u, v);
  }
  for (int i = 0; i < m; ++i) {
    const int op = read_op();
    if (op == EOF) break;
    const int u = static_cast<int>(read_int());
    const int v = static_cast<int>(read_int());
    switch (op) {
      case '+': {
        const ll c = to_residue(read_int());
        split(u, v);
        apply_add(v, c);
        break;
      }
      case '-': {
        const int a = static_cast<int>(read_int());
        const int b = static_cast<int>(read_int());
        cut_edge(u, v);
        link_edge(a, b);
        break;
      }
      case '*': {
        const ll c = to_residue(read_int());
        split(u, v);
        apply_mul(v, c);
        break;
      }
      case '/': {
        split(u, v);
        std::printf("%lld\n", sum[v]);
        break;
      }
      default:
        break;
    }
  }
  return 0;
}