【模板】多标记 LCT
代码如下
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for sums, tags and sizes: a product of two residues
// below `mod` (up to ~2.6e9) would overflow a 32-bit int.
using LL = long long;
// Modulus applied to every stored value, aggregate and lazy tag.
constexpr int mod = 51061;
struct node {
node *l, *r, *p;
int rev, val;
LL sum, add, mul, sz;
node() {
l = r = p = NULL;
sum = add = rev = 0;
mul = val = sz = 1;
}
void unsafe_reverse() {
swap(l, r);
rev ^= 1;
}
void unsafe_add(LL x) {
sum = (sum + x * sz) % mod;
add = (add + x) % mod;
val = (val + x) % mod;
}
void unsafe_mul(LL x) {
sum = sum * x % mod;
add = add * x % mod;
val = val * x % mod;
mul = mul * x % mod;
}
void pull() {
sum = val;
sz = 1;
if (l != NULL) {
l->p = this;
sum = (sum + l->sum) % mod;
sz += l->sz;
}
if (r != NULL) {
r->p = this;
sum = (sum + r->sum) % mod;
sz += r->sz;
}
}
void push() {
if (rev) {
if (l != NULL) {
l->unsafe_reverse();
}
if (r != NULL) {
r->unsafe_reverse();
}
rev = 0;
}
if (mul != 1) {
if (l != NULL) {
l->unsafe_mul(mul);
}
if (r != NULL) {
r->unsafe_mul(mul);
}
mul = 1;
}
if (add != 0) {
if (l != NULL) {
l->unsafe_add(add);
}
if (r != NULL) {
r->unsafe_add(add);
}
add = 0;
}
}
};
// True when v is the root of its splay tree: it has no parent at all, or its
// parent pointer is only a path-parent link (that node lists v as neither its
// left nor its right child). A null pointer is not treated as a root.
bool is_root(node *v) {
    if (v == NULL) return false;
    node *const up = v->p;
    if (up == NULL) return true;
    return !(up->l == v || up->r == v);
}
// Single rotation that lifts v above its parent while keeping in-order order.
// v takes over the parent's link upward: if the parent was a real child of the
// grandparent, that child slot now points at v; if it was a splay root, v just
// inherits the path-parent pointer. pull() on both nodes refreshes aggregates
// and fixes the parent pointers of the subtree that changed sides.
void rotate(node *v) {
    node *par = v->p;
    assert(par != NULL);
    node *grand = par->p;

    v->p = grand;
    if (grand != NULL) {
        if (grand->l == par) grand->l = v;
        if (grand->r == par) grand->r = v;
    }

    if (par->l == v) {
        par->l = v->r;  // v's inner subtree moves across to par
        v->r = par;
    }
    if (par->r == v) {
        par->r = v->l;
        v->l = par;
    }

    par->pull();  // par is now v's child, so refresh it first
    v->pull();
}
// Pushes pending tags top-down along the splay path from v's splay root to v,
// so that none of v's splay ancestors (nor v) carries a pending tag when
// rotations start. An explicit buffer is used instead of recursion so long
// splay paths cannot exhaust the call stack.
void deal_with_push(node *v) {
    static vector<node*> path;
    path.clear();
    for (node *cur = v; ; cur = cur->p) {
        path.push_back(cur);
        if (is_root(cur)) {
            break;
        }
    }
    while (!path.empty()) {
        path.back()->push();  // deepest-last: splay root is pushed first
        path.pop_back();
    }
}
// Moves v to the root of its splay tree using zig / zig-zig / zig-zag steps.
// Pending tags on the way are settled first so rotations see final child order.
void splay(node *v) {
    deal_with_push(v);
    while (!is_root(v)) {
        node *par = v->p;
        if (!is_root(par)) {
            const bool v_left = (par->l == v);
            const bool par_left = (par->p->l == par);
            // Same side (zig-zig): rotate the parent first; opposite (zig-zag): rotate v.
            rotate(v_left == par_left ? par : v);
        }
        rotate(v);
    }
}
// Makes the path from the tree root down to v preferred: each step splays the
// current node, replaces its deeper preferred child with the chain built so
// far, then climbs through the path-parent pointer. v ends with no right child.
void access(node *v) {
    for (node *below = NULL; v != NULL; below = v, v = v->p) {
        splay(v);
        v->r = below;
        v->pull();
    }
}
// Re-roots v's tree at v.
void make_root(node *v) {
    access(v);            // v is now the deepest vertex of its preferred path
    splay(v);             // ...and the root of that path's splay tree
    v->unsafe_reverse();  // flipping depth order makes v the shallowest vertex
}
// Returns the root vertex of v's tree: after access, that is the leftmost
// (shallowest) node of v's splay tree. Tags are pushed on the way down, and
// the found root is splayed to the top of its splay tree before returning.
node* find_root(node *v) {
    access(v);
    splay(v);
    node *cur = v;
    for (; cur->l != NULL; cur = cur->l) {
        cur->push();
    }
    splay(cur);
    return cur;
}
// Adds the edge (v, u) when v and u lie in different trees; otherwise no-op.
void link(node *v, node *u) {
    if (find_root(v) == find_root(u)) {
        return;  // already connected: another edge would form a cycle
    }
    make_root(v);
    v->p = u;  // path-parent link only: u does not adopt v as a splay child
}
// Removes the edge (v, u) if it exists. v's tree is re-rooted at v either way.
void cut(node *v, node *u) {
    make_root(v);
    // Once find_root(u) returns v, v is the splay root of the path v..u.
    // (v, u) is an edge exactly when u is v's splay child with nothing
    // shallower than it apart from v (empty left subtree).
    const bool is_edge = find_root(u) == v && u->p == v && u->l == NULL;
    if (!is_edge) {
        return;
    }
    v->r = NULL;
    u->p = NULL;
    v->pull();
}
// Isolates the path between v and u. Afterwards u is the root of a splay tree
// holding exactly that path, so u->sum and tags applied to u cover the path.
void split(node *v, node *u) {
    make_root(v);  // v becomes the tree root
    access(u);     // preferred path is now v..u
    splay(u);      // expose it through u
}
int main() {
//freopen("data.in", "r", stdin);
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n, m;
cin >> n >> m;
vector<node*> t(n + 1);
for (int i = 1; i <= n; i++) {
t[i] = new node();
}
for (int i = 1; i < n; i++) {
int x, y;
cin >> x >> y;
link(t[x], t[y]);
}
while (m--) {
string opt;
int x, y;
cin >> opt >> x >> y;
if (opt[0] == '+') {
int c;
cin >> c;
split(t[x], t[y]);
t[y]->unsafe_add(c);
}
if (opt[0] == '-') {
int u, v;
cin >> u >> v;
cut(t[x], t[y]);
link(t[u], t[v]);
}
if (opt[0] == '/') {
split(t[x], t[y]);
cout << t[y]->sum % mod << endl;
}
if (opt[0] == '*') {
int c;
cin >> c;
split(t[x], t[y]);
t[y]->unsafe_mul(c);
}
}
for (int i = 1; i <= n; i++) {
delete t[i];
}
return 0;
}