P2486 [SDOI2011]染色
P2486染色
细节较多：线段树 \(pushup\) 时不仅要更新 \(kinds\)，还要更新 \(lco\) 与 \(rco\)。合并左右子区间时，若左子区间最右端颜色 \(rco\) 与右子区间最左端颜色 \(lco\) 相同，颜色段数要减 \(1\)；树剖跳链时相邻两条链段的合并同理。
#include <set>
#include <cmath>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <assert.h>
#include <algorithm>
using namespace std;
// Shorthand macros used throughout this solution.
#define fir first
#define sec second
#define pb push_back
#define mp make_pair
#define LL long long
#define INF (0x3f3f3f3f)
#define mem(a, b) memset(a, b, sizeof (a))
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define Debug(x) cout << #x << " = " << x << endl
// Walk the forward-star adjacency list of vertex x; edge index 0 ends the list.
// `register` is intentionally absent: it is ill-formed since C++17
// (clang rejects it, gcc warns).
#define travle(i, x) for (int i = head[x]; i; i = nxt[i])
// Inclusive ascending loop: i = a, a+1, ..., b.
#define For(i, a, b) for (int (i) = (a); (i) <= (b); ++ (i))
// Inclusive descending loop: i = a, a-1, ..., b.
#define Forr(i, a, b) for (int (i) = (a); (i) >= (b); -- (i))
// Redirect stdin/stdout to "<s>.in" / "<s>.out" (used for local testing).
#define file(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)
namespace io {
// Buffered stdin reader: [pos, end) is the unread part of buf.
static char buf[1<<21], *pos = buf, *end = buf;

// Returns the next input byte as an unsigned char value, or EOF when stdin
// is exhausted. Returning int (not char) keeps EOF distinct from a 0xFF data
// byte and behaves the same whether plain char is signed or unsigned.
inline int getc() {
    if (pos == end) {
        pos = buf;
        end = buf + fread(buf, 1, sizeof buf, stdin);
        if (pos == end) return EOF;
    }
    return static_cast<unsigned char>(*pos++);
}

// Reads a decimal int. Non-digit bytes before the number are skipped; any
// '-' among them makes the result negative. Returns 0 if EOF is reached
// before a digit, so truncated input cannot make the skip loop spin forever.
inline int rint() {
    int x = 0, f = 1, c;
    while (!isdigit(c = getc())) {
        if (c == EOF) return 0;
        if (c == '-') f = -1;
    }
    while (x = (x << 1) + (x << 3) + (c ^ 48), isdigit(c = getc()));
    return x * f;
}

// Same as rint() but for long long values.
inline long long rLL() {
    long long x = 0, f = 1;
    int c;
    while (!isdigit(c = getc())) {
        if (c == EOF) return 0;
        if (c == '-') f = -1;
    }
    while (x = (x << 1) + (x << 3) + (c ^ 48), isdigit(c = getc()));
    return x * f;
}

// Reads one whitespace-delimited token into str and NUL-terminates it
// (str becomes "" if EOF comes first). No bounds check: str must have room
// for the token plus the terminator.
inline void rstr(char *str) {
    int c;
    while (isspace(c = getc()));
    while (c != EOF && !isspace(c)) {
        *str++ = static_cast<char>(c);
        c = getc();
    }
    *str = '\0';
}

// If y is smaller than x, stores y into x and returns true; else returns false.
template<typename T>
inline bool chkmin(T &x, T y) { return x > y ? (x = y, 1) : 0; }
// If y is larger than x, stores y into x and returns true; else returns false.
template<typename T>
inline bool chkmax(T &x, T y) { return x < y ? (x = y, 1) : 0; }
}
using namespace io;

// Array bound: vertices are indexed 1..n, so n must stay below N.
const int N = 100000 + 1;

// n: vertex count, m: number of operations, a[v]: initial colour of vertex v.
int n, m, a[N];

// Forward-star adjacency list; every undirected edge occupies two slots.
int tot, head[N], ver[N<<1], nxt[N<<1];
// Prepends the directed edge u -> v to u's adjacency list, so edges are
// later visited in reverse insertion order.
inline void add(int u, int v) {
    ++tot;
    ver[tot] = v;
    nxt[tot] = head[u];
    head[u] = tot;
}
// Heavy-light decomposition data (indexed by vertex unless noted):
//   cnt     - number of DFS positions handed out so far by DFS2
//   seg[v]  - DFS position of v (1..n)
//   top[v]  - head of the heavy chain containing v
//   dep[v]  - depth (the root, vertex 1, has depth 1)
//   fa[v]   - parent (0 for the root)
//   son[v]  - heavy child (0 if none)
//   size[v] - subtree size
//   rev[p]  - vertex at DFS position p (inverse of seg)
// NOTE(review): under C++17, `using namespace std` also exposes std::size, so
// an unqualified `size` at global scope is ambiguous; refer to it as ::size.
int cnt, seg[N], top[N], dep[N], fa[N], son[N], size[N], rev[N];
#define ls(x) ((x)<<1)
#define rs(x) ((x)<<1|1)
// Segment tree over DFS positions. Node indices reach ~4N, so every per-node
// array (including the lazy tag) is sized N<<2.
//   lco[x]/rco[x] - colour at the left/right end of node x's interval
//   kinds[x]      - number of maximal same-colour runs in the interval
//   tag[x]        - pending "assign colour" tag, -1 when none
//   val[x]        - leaf colour written by build (not read by any visible code)
int val[N<<2], lco[N<<2], rco[N<<2], kinds[N<<2], tag[N<<2];
void pushup(int x) {
kinds[x] = kinds[ls(x)] + kinds[rs(x)] - (rco[ls(x)] == lco[rs(x)]);
lco[x] = lco[ls(x)]; rco[x] = rco[rs(x)];
}
// Builds node x covering DFS positions [l, r]. Each leaf takes the initial
// colour of the vertex at that position (rev[] maps position -> vertex).
void build(int x, int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(ls(x), l, mid);
        build(rs(x), mid + 1, r);
        pushup(x);
        return;
    }
    const int colour = a[rev[l]];
    val[x] = colour;
    lco[x] = colour;
    rco[x] = colour;
    kinds[x] = 1;
}
void pushdown(int x) {
if (~tag[x]) {
tag[ls(x)] = lco[ls(x)] = rco[ls(x)] = tag[x]; kinds[ls(x)] = 1;
tag[rs(x)] = lco[rs(x)] = rco[rs(x)] = tag[x]; kinds[rs(x)] = 1;
tag[x] = -1;
}
}
// Assigns colour `val` to DFS positions [L, R] inside node x's interval
// [l, r]. Fully covered nodes are recoloured in place and only record a
// lazy tag for their children.
void change(int x, int l, int r, int L, int R, int val) {
    const bool covered = L <= l && r <= R;
    if (covered) {
        tag[x] = val;
        lco[x] = val;
        rco[x] = val;
        kinds[x] = 1;
        return;
    }
    pushdown(x);
    const int mid = (l + r) >> 1;
    if (mid >= L) change(ls(x), l, mid, L, R, val);
    if (mid < R) change(rs(x), mid + 1, r, L, R, val);
    pushup(x);
}
// Returns the number of colour runs among DFS positions [L, R] restricted to
// node x's interval [l, r]. As a side channel, the fully covered node that
// starts at L writes its left colour into `a`, and the one ending at R
// writes its right colour into `b`.
int query(int x, int l, int r, int L, int R, int &a, int &b) {
    if (L <= l && r <= R) {
        if (L == l) a = lco[x];
        if (R == r) b = rco[x];
        return kinds[x];
    }
    pushdown(x);
    const int mid = (l + r) >> 1;
    int leftRuns = 0;
    int rightRuns = 0;
    if (mid >= L) leftRuns = query(ls(x), l, mid, L, R, a, b);
    if (R > mid) rightRuns = query(rs(x), mid + 1, r, L, R, a, b);
    // Both halves contributed: runs touching at mid | mid+1 merge if same colour.
    if (leftRuns && rightRuns)
        return leftRuns + rightRuns - (rco[ls(x)] == lco[rs(x)]);
    return leftRuns + rightRuns;
}
// Paints every vertex on the tree path x -> y with colour `val`, one heavy
// chain segment at a time.
void Modify(int x, int y, int val) {
    while (top[x] != top[y]) {
        // Lift the endpoint whose chain head is deeper.
        if (dep[top[y]] > dep[top[x]]) swap(x, y);
        const int head = top[x];
        change(1, 1, n, seg[head], seg[x], val);
        x = fa[head];
    }
    // Same chain now: let x be the shallower endpoint, then paint [x, y].
    if (dep[x] > dep[y]) swap(x, y);
    change(1, 1, n, seg[x], seg[y], val);
}
// Counts the colour runs on the tree path x -> y.
// Both endpoints climb their heavy chains independently. Index 0 / last0
// track the side that started at x; index 1 / last1 track the side that
// started at y.
//   coll[s], colr[s] - colours at the shallow / deep end of the chain segment
//                      just queried on side s (seg[top] <= seg[v])
//   last0, last1     - colour at the shallow end (chain head) of the previous
//                      segment on that side, or -1 before any segment
// Consecutive segments on one side touch at (previous head, its parent), and
// the parent is the new segment's deep end. If those colours match, the two
// runs merge and one is subtracted.
// NOTE(review): -1 is used as "no colour"; assumes input colours are never -1.
int Query(int x, int y) {
int coll[2], colr[2], last0, last1, ans = 0;
coll[0] = coll[1] = colr[0] = colr[1] = last0 = last1 = -1;
while (top[x] != top[y]) {
// Lift the side whose chain head is deeper (ties lift the y side).
if (dep[top[x]] > dep[top[y]]) {
ans += query(1, 1, n, seg[top[x]], seg[x], coll[0], colr[0]);
ans -= colr[0] == last0;
last0 = coll[0];
x = fa[top[x]];
} else {
ans += query(1, 1, n, seg[top[y]], seg[y], coll[1], colr[1]);
ans -= colr[1] == last1;
last1 = coll[1];
y = fa[top[y]];
}
}
// Same chain: make x the shallower endpoint, swapping the per-side
// bookkeeping along with it so last0 still belongs to x's side.
if (dep[x] > dep[y]) swap(x, y), swap(last1, last0);
ans += query(1, 1, n, seg[x], seg[y], coll[1], colr[1]);
// x's side joins this final segment at x (its shallow end, colour coll[1]);
// y's side joins at y (its deep end, colour colr[1]).
ans -= (coll[1] == last0) + (colr[1] == last1);
return ans ;
}
void DFS1(int u, int f) {
fa[u] = f;
dep[u] = dep[f] + 1;
size[u] = 1;
travle(i, u) {
if (ver[i] != f) {
DFS1(ver[i], u);
size[u] += size[ver[i]];
if (size[son[u]] < size[ver[i]]) son[u] = ver[i];
}
}
}
void DFS2(int u, int topf) {
top[u] = topf;
seg[u] = ++cnt;
rev[cnt] = u;
if (son[u]) DFS2(son[u], topf);
travle(i, u) if (ver[i] != fa[u] && ver[i] != son[u])
DFS2(ver[i], ver[i]);
}
int main() {
#ifndef ONLINE_JUDGE
file("P2468");
#endif
memset(tag, -1, sizeof tag);
n = rint(), m = rint();
For(i, 1, n) a[i] = rint();
For(i, 1, n -1 ) {
int u = rint(), v = rint();
add(u, v); add(v, u);
}
DFS1(1, 0);
DFS2(1, 1);
build(1, 1, n);
char op[30];
while (m --) {
rstr(op);
if (op[0] == 'C') {
int u = rint(), v = rint(), val = rint();
Modify(u, v, val);
} else {
int u = rint(), v = rint();
printf("%d\n", Query(u, v));
}
// cout << "ok " << query(1, 1, n, seg[5], seg[5], cnt, cnt) << endl;
}
}