Luogu P3313 [SDOI2014]旅行
题目链接：[Click Here](https://www.luogu.com.cn/problem/P3313)
假的主席树……实际上就是用树链剖分，对每种宗教各维护一棵动态开点线段树（最多十万棵），写起来和主席树比较相似。
真的不能思维僵化啊，写题的时候一定要保证思考方向既灵活又正确。
#include <bits/stdc++.h>
using namespace std;
const int N = 100010;
// tot: number of pool nodes allocated so far (node 0 is never handed out).
// rt[c]: root node of the segment tree holding the cities of religion c
//        (0 = that religion's tree is still empty).
int tot, rt[N];
// Shared node pool for every dynamically-opened segment tree.
// Index 0 is the empty sentinel: ls = rs = sum = max = 0, so a missing child
// contributes nothing to sums or maxima.
//
// Capacity: a segment tree over [1, n] with n <= 1e5 has depth 18, so a single
// modify() allocates at most 18 nodes. Allocation happens for the n initial
// inserts and for each religion change (the insert into the new religion's
// tree); weight changes and the "clear old religion" update reuse an existing
// path. Assuming n, m <= 1e5 (confirm against the problem limits) that is at
// most 18 * 2e5 = 3.6e6 nodes, which overflows the previous N << 5 (~3.2e6).
// N * 40 (~4.0e6 nodes, ~64 MB) covers the worst case.
struct Segment_Node {
    int ls, rs, sum, max;
    // Reset this node to the empty state.
    void Init () {
        ls = rs = sum = max = 0;
    }
}t[N * 40];
// Midpoint of the current node's range; expands in terms of the caller's
// local variables named l and r.
#define mid ((l + r) >> 1)
// Rebuild node p's aggregates from its two children. An absent child has
// index 0, the zeroed sentinel, so it adds 0 to the sum and 0 to the max.
void push_up (int p) {
    const auto &lc = t[t[p].ls];
    const auto &rc = t[t[p].rs];
    t[p].sum = lc.sum + rc.sum;
    t[p].max = max (lc.max, rc.max);
}
// Point-assign: set position pos to val in the tree rooted at rt, which
// covers [l, r]. Missing nodes along the path are allocated from the pool,
// and rt (passed by reference) is filled in when it was 0.
// Leaf: sum and max both become val; internal nodes are rebuilt on unwind.
void modify (int &rt, int l, int r, int pos, int val) {
    if (!rt) {
        rt = ++tot;
        t[rt].Init ();
    }
    if (l == r) {
        t[rt].sum = val;
        t[rt].max = val;
        return;
    }
    if (pos > mid) modify (t[rt].rs, mid + 1, r, pos, val);
    else modify (t[rt].ls, l, mid, pos, val);
    push_up (rt);
}
// Sum of the values stored in [nl, nr] within the subtree rooted at p,
// where p covers [l, r]. An empty subtree (p == 0) contributes 0.
int query_sum (int l, int r, int nl, int nr, int p) {
    if (!p) return 0;
    if (nl <= l && r <= nr) return t[p].sum;
    const int from_left = (nl <= mid) ? query_sum (l, mid, nl, nr, t[p].ls) : 0;
    const int from_right = (nr > mid) ? query_sum (mid + 1, r, nl, nr, t[p].rs) : 0;
    return from_left + from_right;
}
// Maximum value stored in [nl, nr] within the subtree rooted at p, where p
// covers [l, r]. The result is never below 0: an empty subtree yields 0 and
// partial results are folded together with 0.
int query_max (int l, int r, int nl, int nr, int p) {
    if (!p) return 0;
    if (nl <= l && r <= nr) return t[p].max;
    const int from_left = (nl <= mid) ? query_max (l, mid, nl, nr, t[p].ls) : 0;
    const int from_right = (nr > mid) ? query_max (mid + 1, r, nl, nr, t[p].rs) : 0;
    return max (0, max (from_left, from_right));
}
// Linked adjacency lists: head[u] is the most recently added edge leaving u
// (0 = none) and e[i].nxt links to the previous one. Edge indices start at 1
// so that 0 terminates every chain.
int cnt, head[N];
struct edge {
    int nxt, to;
}e[N << 1];
// Record the undirected edge u-v as two directed edges (u->v, then v->u).
void add_len (int u, int v) {
    auto link = [] (int from, int to) {
        e[++cnt] = edge{head[from], to};
        head[from] = cnt;
    };
    link (u, v);
    link (v, u);
}
int sz[N], pre[N], son[N], deep[N];
// First HLD pass over the subtree of u (whose parent is fa).
// Fills pre (parent), deep (depth, with deep[0] == 0), sz (subtree size) and
// son (child with the largest subtree; on ties the first one met in
// adjacency order wins; stays 0 for a leaf).
void dfs1 (int u, int fa) {
    pre[u] = fa;
    deep[u] = deep[fa] + 1;
    sz[u] = 1;
    int heaviest = 0;
    for (int i = head[u]; i != 0; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa) continue;
        dfs1 (v, u);
        sz[u] += sz[v];
        if (heaviest < sz[v]) {
            heaviest = sz[v];
            son[u] = v;
        }
    }
}
int n, m, rel[N], val[N];
int dfn[N], top[N], nodw[N];
// Second HLD pass: number nodes in DFS order, visiting the heavy child first,
// so every heavy chain occupies a contiguous dfn range; top[u] is the head of
// u's chain. dfn[0] is used as the running counter. nodw[dfn[u]] records
// val[u]; NOTE(review): nodw does not appear to be read anywhere in this file.
void dfs2 (int u, int tp) {
    dfn[u] = ++dfn[0];
    top[u] = tp;
    nodw[dfn[u]] = val[u];
    if (son[u] == 0) return;
    dfs2 (son[u], tp);  // heavy child continues the current chain
    for (int i = head[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == pre[u] || v == son[u]) continue;
        dfs2 (v, v);  // every light child starts a new chain
    }
}
int get_sum (int u, int v, int _rel) {
int res = 0;
while (top[u] != top[v]) {
//printf ("u = %d, v = %d, top[u] = %d, top[v] = %d, rel = %d\n", u, v, top[u], top[v], _rel);
if (deep[top[u]] > deep[top[v]]) {
//printf ("query_sum (%d, %d, %d, %d, %d) = %d\n", 1, n, dfn[top[u]], dfn[u], rt[_rel], query_sum (1, n, dfn[top[u]], dfn[u], rt[_rel]));
res += query_sum (1, n, dfn[top[u]], dfn[u], rt[_rel]), u = pre[top[u]];
} else {
//printf ("query_sum (%d, %d, %d, %d, %d) = %d\n", 1, n, dfn[top[v]], dfn[v], rt[_rel], query_sum (1, n, dfn[top[v]], dfn[v], rt[_rel]));
res += query_sum (1, n, dfn[top[v]], dfn[v], rt[_rel]), v = pre[top[v]];
}
}
if (deep[u] > deep[v]) swap (u, v);
res += query_sum (1, n, dfn[u], dfn[v], rt[_rel]);
//printf ("u = %d, v = %d, top[u] = %d, top[v] = %d, rel = %d\n", u, v, top[u], top[v], _rel);
//printf ("query_sum (%d, %d, %d, %d, %d) = %d\n", 1, n, dfn[u], dfn[v], rt[_rel], query_sum (1, n, dfn[u], dfn[v], rt[_rel]));
return res;
}
int get_max (int u, int v, int _rel) {
int res = 0;
while (top[u] != top[v]) {
if (deep[top[u]] > deep[top[v]]) {
res = max (res, query_max (1, n, dfn[top[u]], dfn[u], rt[_rel])), u = pre[top[u]];
} else {
res = max (res, query_max (1, n, dfn[top[v]], dfn[v], rt[_rel])), v = pre[top[v]];
}
}
if (deep[u] > deep[v]) swap (u, v);
res = max (res, query_max (1, n, dfn[u], dfn[v], rt[_rel]));
return res;
}
int main () {
t[0].Init ();
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
cin >> val[i] >> rel[i];
}
for (int i = 1; i <= n - 1; ++i) {
static int u, v;
cin >> u >> v;
add_len (u, v);
}
dfs1 (1, 0);
dfs2 (1, 1);
for (int u = 1; u <= n; ++u) {
//printf ("u = %d, dfn[u] = %d, sz[u] = %d, top[u] = %d, son[u] = %d, deep[u] = %d\n", u, dfn[u], sz[u], top[u], son[u], deep[u]);
//printf ("u = %d, rel[u] = %d, rt[rel[u]] = %d\n", u, rel[u], rt[rel[u]]);
modify (rt[rel[u]], 1, n, dfn[u], val[u]);
//printf ("rt[rel[u]] = %d\n", rt[rel[u]]);
}
for (int i = 1; i <= m; ++i) {
static char opt[2];
static int x, y;
cin >> opt >> x >> y;
if (opt[1] == 'C') { //点x宗教改为y
modify (rt[rel[x]], 1, n, dfn[x], 0);
rel[x] = y;
modify (rt[rel[x]], 1, n, dfn[x], val[x]);
}
if (opt[1] == 'W') { //权值改为y
val[x] = y;
modify (rt[rel[x]], 1, n, dfn[x], val[x]);
}
if (opt[1] == 'S') {
cout << get_sum (x, y, rel[y]) << endl;;
}
if (opt[1] == 'M') {
cout << get_max (x, y, rel[y]) << endl;
}
}
}