BZOJ2243. [SDOI2011]染色
树链剖分。线段树维护区间最左端和最右端的颜色以及区间内的颜色段数；合并两个相邻区间时，颜色段数 = 左区间段数 + 右区间段数 − [左区间右端点颜色 == 右区间左端点颜色]。路径查询跳链时同理：若链顶 top[u] 与其父亲 fa[top[u]] 颜色相同，两段在连接处合为一段，答案需再减 1。
// BZOJ 2243 [SDOI2011] — path color assignment / path color-segment count
// on a tree, via heavy-light decomposition plus a segment tree.
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>

// Minimal fast integer reader over stdin.
namespace IO {
// Base case of the variadic reader: nothing left to read.
inline void read() {}

// Reads one decimal integer (optionally preceded by '-') into x, then the
// remaining arguments in order. Other characters before the number are
// skipped; at EOF the target is left as 0.
template <typename T, typename... Rest>
inline void read(T &x, Rest &... rest) {
    bool negative = false;
    // int, not char: EOF must stay distinguishable, and std::isdigit is only
    // defined for values representable as unsigned char (or EOF).
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        negative = (ch == '-');  // only a '-' right before the digits counts
        ch = std::getchar();
    }
    x = 0;
    while (std::isdigit(ch)) {  // isdigit(EOF) is false, so this stops at EOF
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    if (negative) x = -x;
    read(rest...);
}
}  // namespace IO

constexpr int N = 100000 + 7;  // capacity: vertices are numbered 1..n with n < N

int n, m;
int col[N];                  // initial color of each vertex
int fa[N], dep[N], sz[N];    // parent, depth (root = 1), subtree size
int son[N];                  // heavy child (0 = none; sz[0] stays 0)
int top[N], dfn[N];          // head of the vertex's heavy chain; HLD position
int wt[N];                   // wt[dfn[u]] == col[u]: colors in HLD order
int dfs_clock;               // last HLD position handed out
std::vector<int> adj[N];     // undirected adjacency lists

// Segment tree over HLD positions 1..n. Each node stores the number of
// maximal same-color runs in its interval (cnt) and the colors at its left
// and right ends (lc / rc), so two adjacent intervals merge as
//   cnt = cnt_left + cnt_right - [rc_left == lc_right].
// Range assignment uses a lazy tag; has_lazy marks it explicitly so any int
// (including negatives) is a valid color.
struct SegTree {
    static const int NN = N * 4;
    int cnt[NN];
    int lc[NN], rc[NN];
    int lazy[NN];
    bool has_lazy[NN];

    static int ls(int p) { return p << 1; }
    static int rs(int p) { return p << 1 | 1; }

    // Recomputes node p from its two children.
    void pushup(int p) {
        cnt[p] = cnt[ls(p)] + cnt[rs(p)] - (rc[ls(p)] == lc[rs(p)] ? 1 : 0);
        lc[p] = lc[ls(p)];
        rc[p] = rc[rs(p)];
    }

    // Paints node p's whole interval with color c and leaves c pending for
    // its children.
    void apply(int p, int c) {
        cnt[p] = 1;
        lc[p] = rc[p] = c;
        lazy[p] = c;
        has_lazy[p] = true;
    }

    // Forwards a pending assignment on p to both children.
    void pushdown(int p) {
        if (!has_lazy[p]) return;
        apply(ls(p), lazy[p]);
        apply(rs(p), lazy[p]);
        has_lazy[p] = false;
    }

    // Builds node p covering [l, r] from wt[]. Requires l <= r.
    void build(int p, int l, int r) {
        has_lazy[p] = false;
        if (l == r) {
            cnt[p] = 1;
            lc[p] = rc[p] = wt[l];
            return;
        }
        int mid = (l + r) / 2;
        build(ls(p), l, mid);
        build(rs(p), mid + 1, r);
        pushup(p);
    }

    // Assigns color c to positions [x, y], which must intersect [l, r].
    void update(int p, int l, int r, int x, int y, int c) {
        if (x <= l && r <= y) {
            apply(p, c);
            return;
        }
        pushdown(p);
        int mid = (l + r) / 2;
        if (x <= mid) update(ls(p), l, mid, x, y, c);
        if (y > mid) update(rs(p), mid + 1, r, x, y, c);
        pushup(p);
    }

    // Returns the number of same-color runs within positions [x, y]
    // (x <= y, and [x, y] must intersect [l, r]).
    int query(int p, int l, int r, int x, int y) {
        if (x <= l && r <= y) return cnt[p];
        pushdown(p);
        int mid = (l + r) / 2;
        if (y <= mid) return query(ls(p), l, mid, x, y);
        if (x > mid) return query(rs(p), mid + 1, r, x, y);
        int res = query(ls(p), l, mid, x, y) + query(rs(p), mid + 1, r, x, y);
        // The halves meet at mid / mid + 1; equal colors there are one run.
        if (rc[ls(p)] == lc[rs(p)]) --res;
        return res;
    }

    // Returns the color at a single position. A pending tag on an ancestor
    // means its whole interval is uniform, so the walk can stop there
    // without pushing down.
    int color_at(int p, int l, int r, int pos) const {
        while (l != r) {
            if (has_lazy[p]) return lazy[p];
            int mid = (l + r) / 2;
            if (pos <= mid) {
                p = ls(p);
                r = mid;
            } else {
                p = rs(p);
                l = mid + 1;
            }
        }
        return lc[p];
    }
};

SegTree seg;

// Heavy-light decomposition rooted at vertex 1. It is iterative so that a
// path-shaped tree with ~1e5 vertices cannot overflow the call stack.
// Pass 1 fills fa, dep, sz and son; pass 2 fills top, dfn and wt so that
// every heavy chain occupies consecutive positions, head first.
void decompose() {
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stk(1, 1);
    fa[1] = 0;
    dep[1] = 1;
    while (!stk.empty()) {
        int u = stk.back();
        stk.pop_back();
        order.push_back(u);
        sz[u] = 1;
        for (int v : adj[u]) {
            if (v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }
    // Reverse preorder visits each child before its parent, so sz[u] is
    // final when it is added to the parent and compared for heaviness.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        int u = *it;
        int p = fa[u];
        if (p == 0) continue;  // root; keep sz[0] == 0 for the son[] sentinel
        sz[p] += sz[u];
        if (sz[u] > sz[son[p]]) son[p] = u;
    }

    stk.assign(1, 1);
    while (!stk.empty()) {
        int head = stk.back();
        stk.pop_back();
        // Walk the heavy chain starting at head. Each light child starts a
        // new chain later.
        for (int u = head; u != 0; u = son[u]) {
            top[u] = head;
            dfn[u] = ++dfs_clock;
            wt[dfs_clock] = col[u];
            for (int v : adj[u])
                if (v != fa[u] && v != son[u]) stk.push_back(v);
        }
    }
}

// Assigns color c to every vertex on the tree path u..v.
void paint_path(int u, int v, int c) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        seg.update(1, 1, n, dfn[top[u]], dfn[u], c);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    seg.update(1, 1, n, dfn[u], dfn[v], c);
}

// Returns the number of maximal same-color segments on the tree path u..v.
int count_segments(int u, int v) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        ans += seg.query(1, 1, n, dfn[top[u]], dfn[u]);
        // top[u] and its parent are adjacent on the path. If they share a
        // color, the runs meeting there are one segment, not two.
        if (seg.color_at(1, 1, n, dfn[top[u]]) ==
            seg.color_at(1, 1, n, dfn[fa[top[u]]]))
            --ans;
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    return ans + seg.query(1, 1, n, dfn[u], dfn[v]);
}

// Input:
//   n m, then n colors, then n - 1 edges, then m operations.
//   Each operation is "C a b c" (paint path a..b with c) or "Q a b"
//   (print the segment count on path a..b).
// Assumes vertex ids are in [1, n] and the edges form a tree.
int main() {
    IO::read(n, m);
    if (n < 1 || n >= N) {
        // n == 0 would make build() recurse forever; n >= N overflows arrays.
        std::fputs("n out of supported range\n", stderr);
        return 1;
    }
    for (int i = 1; i <= n; ++i) IO::read(col[i]);
    for (int i = 1; i < n; ++i) {
        int u, v;
        IO::read(u, v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    decompose();
    seg.build(1, 1, n);

    char op[10];
    while (m-- > 0) {
        // Width-limited so an over-long token cannot overflow op.
        if (std::scanf("%9s", op) != 1) break;
        int u, v;
        IO::read(u, v);
        if (op[0] == 'C') {
            int c;
            IO::read(c);
            paint_path(u, v, c);
        } else {
            std::printf("%d\n", count_segments(u, v));
        }
    }
    return 0;
}